diff --git a/app.py b/app.py
index 4bb8f27e78f37fb2f54d8e2299b82ca7ea95b13e..d841a5f0d7f31c6a41c9f95f052fff992509fd42 100644
--- a/app.py
+++ b/app.py
@@ -1,11 +1,8 @@
 import os
-import sys
 from typing import List, Tuple
 
 PWD = os.path.dirname(__file__)
-sys.path.append(os.path.join(PWD, "cosmos-transfer1"))
-
 import subprocess
 
 subprocess.run("pip install flash-attn --no-build-isolation", env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}, shell=True)
 
@@ -26,23 +23,10 @@ except Exception as e:
     print(f"Authentication failed: {e}")
 
 # download checkpoints
-subprocess.run(
-    [
-        "python",
-        os.path.join(PWD, "cosmos-transfer1", "scripts", "download_checkpoints.py"),
-        "--output_dir",
-        os.path.join(PWD, "cosmos-transfer1", "checkpoints"),
-        "--model",
-        "7b_av",
-    ],
-    shell=True,
-)
-# subprocess.run(
-#     f"python cosmos-transfer1/scripts/download_checkpoints.py \
-#     --hf_token {hf_token} \
-#     --output_dir cosmos-transfer1/checkpoints/ \
-#     --model 7b_av"
-# )
+from download_checkpoints import main as download_checkpoints
+
+os.makedirs("./checkpoints", exist_ok=True)
+download_checkpoints(hf_token="", output_dir="./checkpoints", model="7b_av")
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false" # Workaround to suppress MP warning
 
diff --git a/cosmos-transfer1/.flake8 b/cosmos-transfer1/.flake8
deleted file mode 100644
index ae19ba1d115ecf5fd4f007490451a0e40c2abc5c..0000000000000000000000000000000000000000
--- a/cosmos-transfer1/.flake8
+++ /dev/null
@@ -1,10 +0,0 @@
-[flake8]
-enable-extensions = G
-select = B,C,E,F,G,P,SIM1,T4,W,B9
-max-line-length = 120
-# C408 ignored because we like the dict keyword argument syntax
-# E501 is not flexible enough, we're using B950 instead
-ignore =
-    E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,E226,E265
-exclude =
-    third_party
diff --git a/cosmos-transfer1/.gitignore b/cosmos-transfer1/.gitignore
deleted file mode 100644
index 3776f75385467dfa87681b4378794498b37410f8..0000000000000000000000000000000000000000
--- a/cosmos-transfer1/.gitignore
+++ /dev/null
@@ -1,242 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -# Misc -outputs/ -checkpoints/* -!checkpoints/README.md - -# Data types -*.jit -*.pt -*.hdr -*.webp -*.pgm -*.tiff -*.tif -*.tar -*.tar.gz -*.gz -*.pkl -*.pt -*.bin - -# Other uncheckable file types -*.zip -*.exe -*.dll -*.swp -*.vscode -*.DS_Store -*.pyc -*Thumbs.db -*.patch - -# Credential information that should never be checked in -credentials -*.secret - -# ------------------------ BELOW IS AUTO-GENERATED FOR PYTHON REPOS ------------------------ - -# Byte-compiled / optimized / DLL files -**/__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -results/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.config -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Third party -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - -# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# ruff -.ruff_cache - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ -CLIP -.devcontainer/devcontainer.json - -# Coverage -.coverage -coverage.xml - -# JUnit Reports -report.xml - -# CI-CD -temp/ -envs.txt -manifest.json - - -# locks and t5 temp files -*.locks* -*.no_exist* -*models--t5* - -# OneLogger -wandb/ -onelogger.err -onelogger.log diff --git a/cosmos-transfer1/.pre-commit-config.yaml b/cosmos-transfer1/.pre-commit-config.yaml deleted file mode 100644 index d82ca7ba148ca42f4adaf78648b672d73beee30a..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/.pre-commit-config.yaml +++ /dev/null @@ -1,55 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -default_language_version: - python: python3.10 -repos: - - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 - hooks: - - id: flake8 - args: - - --max-line-length=120 - - --ignore=E501,F401,E203,E402,E265,E741,F841,F821,F811,W503,E231,E225,E702 - exclude: ^dist/|^third_party/ - - - repo: https://github.com/psf/black - rev: 23.12.1 - hooks: - - id: black - args: [--line-length=120] - exclude: ^dist/|^third_party/ - - - repo: https://github.com/timothycrosley/isort - rev: 5.12.0 - hooks: - - id: isort - args: [--line-length=120] - - - repo: https://github.com/MarcoGorelli/absolufy-imports - rev: v0.3.1 - hooks: - - id: absolufy-imports - - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 - hooks: - - id: trailing-whitespace - exclude: ^tests/.*/fixtures/.* - args: [--markdown-linebreak-ext=md] - - id: end-of-file-fixer - exclude: ^tests/.*/fixtures/.* - - id: check-added-large-files - args: ['--maxkb=2000'] diff --git a/cosmos-transfer1/ATTRIBUTIONS.md b/cosmos-transfer1/ATTRIBUTIONS.md deleted file mode 100644 index 669587070c292fb9e9ffff4ae7b383bc52677b3b..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/ATTRIBUTIONS.md +++ /dev/null @@ -1,1661 +0,0 @@ -# Open Source License Attribution - - Cosmos uses Open Source components. 
You can find the details of these open-source projects along with license information below, sorted alphabetically. - We are grateful to the developers for their contributions to open source and acknowledge these below. - -## Better-Profanity - [MIT License](https://github.com/snguyenthanh/better_profanity/blob/master/LICENSE) - - ``` - - Copyright (c) 2018 The Python Packaging Authority - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - - ``` - -## FFmpeg - [FFMPEG License](https://github.com/FFmpeg/FFmpeg/blob/master/LICENSE.md) - - ``` - # License - - Most files in FFmpeg are under the GNU Lesser General Public License version 2.1 - or later (LGPL v2.1+). Read the file `COPYING.LGPLv2.1` for details. Some other - files have MIT/X11/BSD-style licenses. In combination the LGPL v2.1+ applies to - FFmpeg. - - Some optional parts of FFmpeg are licensed under the GNU General Public License - version 2 or later (GPL v2+). See the file `COPYING.GPLv2` for details. None of - these parts are used by default, you have to explicitly pass `--enable-gpl` to - configure to activate them. In this case, FFmpeg's license changes to GPL v2+. - - Specifically, the GPL parts of FFmpeg are: - - - libpostproc - - optional x86 optimization in the files - - `libavcodec/x86/flac_dsp_gpl.asm` - - `libavcodec/x86/idct_mmx.c` - - `libavfilter/x86/vf_removegrain.asm` - - the following building and testing tools - - `compat/solaris/make_sunver.pl` - - `doc/t2h.pm` - - `doc/texi2pod.pl` - - `libswresample/tests/swresample.c` - - `tests/checkasm/*` - - `tests/tiny_ssim.c` - - the following filters in libavfilter: - - `signature_lookup.c` - - `vf_blackframe.c` - - `vf_boxblur.c` - - `vf_colormatrix.c` - - `vf_cover_rect.c` - - `vf_cropdetect.c` - - `vf_delogo.c` - - `vf_eq.c` - - `vf_find_rect.c` - - `vf_fspp.c` - - `vf_histeq.c` - - `vf_hqdn3d.c` - - `vf_kerndeint.c` - - `vf_lensfun.c` (GPL version 3 or later) - - `vf_mcdeint.c` - - `vf_mpdecimate.c` - - `vf_nnedi.c` - - `vf_owdenoise.c` - - `vf_perspective.c` - - `vf_phase.c` - - `vf_pp.c` - - `vf_pp7.c` - - `vf_pullup.c` - - `vf_repeatfields.c` - - `vf_sab.c` - - `vf_signature.c` - - `vf_smartblur.c` - - `vf_spp.c` - - `vf_stereo3d.c` - - `vf_super2xsai.c` - - `vf_tinterlace.c` - - `vf_uspp.c` - - `vf_vaguedenoiser.c` - - `vsrc_mptestsrc.c` - - Should you, for whatever reason, prefer to use version 3 of the (L)GPL, then - the configure parameter `--enable-version3` will activate this licensing option - for you. 
Read the file `COPYING.LGPLv3` or, if you have enabled GPL parts, - `COPYING.GPLv3` to learn the exact legal terms that apply in this case. - - There are a handful of files under other licensing terms, namely: - - * The files `libavcodec/jfdctfst.c`, `libavcodec/jfdctint_template.c` and - `libavcodec/jrevdct.c` are taken from libjpeg, see the top of the files for - licensing details. Specifically note that you must credit the IJG in the - documentation accompanying your program if you only distribute executables. - You must also indicate any changes including additions and deletions to - those three files in the documentation. - * `tests/reference.pnm` is under the expat license. - - - ## External libraries - - FFmpeg can be combined with a number of external libraries, which sometimes - affect the licensing of binaries resulting from the combination. - - ### Compatible libraries - - The following libraries are under GPL version 2: - - avisynth - - frei0r - - libcdio - - libdavs2 - - librubberband - - libvidstab - - libx264 - - libx265 - - libxavs - - libxavs2 - - libxvid - - When combining them with FFmpeg, FFmpeg needs to be licensed as GPL as well by - passing `--enable-gpl` to configure. - - The following libraries are under LGPL version 3: - - gmp - - libaribb24 - - liblensfun - - When combining them with FFmpeg, use the configure option `--enable-version3` to - upgrade FFmpeg to the LGPL v3. - - The VMAF, mbedTLS, RK MPI, OpenCORE and VisualOn libraries are under the Apache License - 2.0. That license is incompatible with the LGPL v2.1 and the GPL v2, but not with - version 3 of those licenses. So to combine these libraries with FFmpeg, the - license version needs to be upgraded by passing `--enable-version3` to configure. - - The smbclient library is under the GPL v3, to combine it with FFmpeg, - the options `--enable-gpl` and `--enable-version3` have to be passed to - configure to upgrade FFmpeg to the GPL v3. - - ### Incompatible libraries - - There are certain libraries you can combine with FFmpeg whose licenses are not - compatible with the GPL and/or the LGPL. If you wish to enable these - libraries, even in circumstances that their license may be incompatible, pass - `--enable-nonfree` to configure. This will cause the resulting binary to be - unredistributable. - - The Fraunhofer FDK AAC and OpenSSL libraries are under licenses which are - incompatible with the GPLv2 and v3. To the best of our knowledge, they are - compatible with the LGPL. - - ``` - -## Hydra-core [MIT License](https://github.com/facebookresearch/hydra/blob/main/LICENSE) - - ``` - - MIT License - - Copyright (c) Facebook, Inc. and its affiliates. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - - ``` - -## ImageIo - [BSD 2-Clause "Simplified" License](https://github.com/imageio/imageio/blob/master/LICENSE) - - ``` - - Copyright (c) 2014-2022, imageio developers - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - ``` - -## Iopath - [MIT License](https://github.com/facebookresearch/iopath/blob/main/LICENSE) - - ``` - MIT License - - Copyright (c) Facebook, Inc. and its affiliates. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - - ``` - -## Llama-Guard-3-8B [META LLAMA 3 COMMUNITY LICENSE](https://github.com/meta-llama/llama3/blob/main/LICENSE) - - ``` - - META LLAMA 3 COMMUNITY LICENSE AGREEMENT - - Meta Llama 3 Version Release Date: April 18, 2024 - - “Agreement” means the terms and conditions for use, reproduction, distribution, and - modification of the Llama Materials set forth herein. - - “Documentation” means the specifications, manuals, and documentation accompanying Meta - Llama 3 distributed by Meta at https://llama.meta.com/get-started/. 
- - “Licensee” or “you” means you, or your employer or any other person or entity (if you are - entering into this Agreement on such person or entity’s behalf), of the age required under - applicable laws, rules, or regulations to provide legal consent and that has legal authority - to bind your employer or such other person or entity if you are entering into this Agreement - on their behalf. - - “Meta Llama 3” means the foundational large language models and software and algorithms, - including machine-learning model code, trained model weights, inference-enabling code, - training-enabling code, fine-tuning-enabling code, and other elements of the foregoing - distributed by Meta at https://llama.meta.com/llama-downloads. - - “Llama Materials” means, collectively, Meta’s proprietary Meta Llama 3 and Documentation - (and any portion thereof) made available under this Agreement. - - “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are - an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, - Inc. (if you are located outside of the EEA or Switzerland). - - By clicking “I Accept” below or by using or distributing any portion or element of the Llama - Materials, you agree to be bound by this Agreement. - - 1. License Rights and Redistribution. - - a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and - royalty-free limited license under Meta’s intellectual property or other rights owned by - Meta embodied in the Llama Materials to use, reproduce, distribute, copy, create derivative - works of, and make modifications to the Llama Materials. - - b. Redistribution and Use. - i. If you distribute or make available the Llama Materials (or any derivative works - thereof), or a product or service that uses any of them, including another AI model, you - shall (A) provide a copy of this Agreement with any such Llama Materials; and (B) - prominently display “Built with Meta Llama 3” on a related website, user interface, - blogpost, about page, or product documentation. If you use the Llama Materials to create, - train, fine tune, or otherwise improve an AI model, which is distributed or made available, - you shall also include “Llama 3” at the beginning of any such AI model name. - - ii. If you receive Llama Materials, or any derivative works thereof, from a Licensee as - part of an integrated end user product, then Section 2 of this Agreement will not apply - to you. - - iii. You must retain in all copies of the Llama Materials that you distribute the - following attribution notice within a “Notice” text file distributed as a part of such - copies: “Meta Llama 3 is licensed under the Meta Llama 3 Community License, Copyright © - Meta Platforms, Inc. All Rights Reserved.” - - iv. Your use of the Llama Materials must comply with applicable laws and regulations - (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy - for the Llama Materials (available at https://llama.meta.com/llama3/use-policy), which - is hereby incorporated by reference into this Agreement. - - v. You will not use the Llama Materials or any output or results of the Llama Materials - to improve any other large language model (excluding Meta Llama 3 or derivative works - thereof). - - 2. Additional Commercial Terms. 
- - If, on the Meta Llama 3 version release date, the monthly active users of the products or - services made available by or for Licensee, or Licensee’s affiliates, is greater than 700 - million monthly active users in the preceding calendar month, you must request a license - from Meta, which Meta may grant to you in its sole discretion, and you are not authorized - to exercise any of the rights under this Agreement unless or until Meta otherwise expressly - grants you such rights. - - 3. Disclaimer of Warranty. - - UNLESS REQUIRED BY APPLICABLE LAW, THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM - ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL - WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY - WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING - THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE LLAMA MATERIALS - AND ANY OUTPUT AND RESULTS. - - 4. Limitation of Liability. - - IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN - CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, - FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR - PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY - OF THE FOREGOING. - - 5. Intellectual Property. - - a. No trademark licenses are granted under this Agreement, and in connection with the Llama - Materials, neither Meta nor Licensee may use any name or mark owned by or associated with - the other or any of its affiliates, except as required for reasonable and customary use in - describing and redistributing the Llama Materials or as set forth in this Section 5(a). - Meta hereby grants you a license to use “Llama 3” (the “Mark”) solely as required to comply - with the last sentence of Section 1.b.i. You will comply with Meta’s brand guidelines - (currently accessible at https://about.meta.com/brand/resources/meta/company-brand/). - All goodwill arising out of your use of the Mark will inure to the benefit of Meta. - - b. Subject to Meta’s ownership of Llama Materials and derivatives made by or for Meta, with - respect to any derivative works and modifications of the Llama Materials that are made by - you, as between you and Meta, you are and will be the owner of such derivative works and - modifications. - - c. If you institute litigation or other proceedings against Meta or any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Llama Materials or Meta Llama 3 - outputs or results, or any portion of any of the foregoing, constitutes infringement of - intellectual property or other rights owned or licensable by you, then any licenses granted - to you under this Agreement shall terminate as of the date such litigation or claim is filed - or instituted. You will indemnify and hold harmless Meta from and against any claim by any - third party arising out of or related to your use or distribution of the Llama Materials. - - 6. Term and Termination. - - The term of this Agreement will commence upon your acceptance of this Agreement or access - to the Llama Materials and will continue in full force and effect until terminated in - accordance with the terms and conditions herein. 
Meta may terminate this Agreement if you - are in breach of any term or condition of this Agreement. Upon termination of this Agreement, - you shall delete and cease use of the Llama Materials. Sections 3, 4, and 7 shall survive - the termination of this Agreement. - - 7. Governing Law and Jurisdiction. - - This Agreement will be governed and construed under the laws of the State of California - without regard to choice of law principles, and the UN Convention on Contracts for the - International Sale of Goods does not apply to this Agreement. The courts of California - shall have exclusive jurisdiction of any dispute arising out of this Agreement. - - META LLAMA 3 ACCEPTABLE USE POLICY - - Meta is committed to promoting safe and fair use of its tools and features, including Meta - Llama 3. If you access or use Meta Llama 3, you agree to this Acceptable Use Policy - (“Policy”). The most recent copy of this policy can be found at - https://llama.meta.com/llama3/use-policy. - - Prohibited Uses - - We want everyone to use Meta Llama 3 safely and responsibly. You agree you will not use, or - allow others to use, Meta Llama 3 to: - - 1. Violate the law or others’ rights, including to: - - a. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal - or unlawful activity or content, such as: - - i. Violence or terrorism - ii. Exploitation or harm to children, including the solicitation, creation, acquisition, - or dissemination of child exploitative content or failure to report Child Sexual Abuse - Material - iii. Human trafficking, exploitation, and sexual violence - iv. The illegal distribution of information or materials to minors, including obscene - materials, or failure to employ legally required age-gating in connection with such - information or materials - v. Sexual solicitation - vi. Any other criminal activity - - b. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or - bullying of individuals or groups of individuals - - c. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful - conduct in the provision of employment, employment benefits, credit, housing, other economic - benefits, or other essential goods and services - - d. Engage in the unauthorized or unlicensed practice of any profession including, but not - limited to, financial, legal, medical/health, or related professional practices - - e. Collect, process, disclose, generate, or infer health, demographic, or other sensitive - personal or private information about individuals without rights and consents required by - applicable laws - - f. Engage in or facilitate any action or generate any content that infringes, misappropriates, - or otherwise violates any third-party rights, including the outputs or results of any - products or services using the Llama Materials - - g. Create, generate, or facilitate the creation of malicious code, malware, computer viruses - or do anything else that could disable, overburden, interfere with or impair the proper - working, integrity, operation, or appearance of a website or computer system - - 2. Engage in, promote, incite, facilitate, or assist in the planning or development of - activities that present a risk of death or bodily harm to individuals, including use of Meta - Llama 3 related to the following: - - a. 
Military, warfare, nuclear industries or applications, espionage, use for materials or - activities that are subject to the International Traffic Arms Regulations (ITAR) maintained - by the United States Department of State - b. Guns and illegal weapons (including weapon development) - c. Illegal drugs and regulated/controlled substances - d. Operation of critical infrastructure, transportation technologies, or heavy machinery - e. Self-harm or harm to others, including suicide, cutting, and eating disorders - f. Any content intended to incite or promote violence, abuse, or any infliction of bodily - harm to an individual - - 3. Intentionally deceive or mislead others, including use of Meta Llama 3 related to the - following: - - a. Generating, promoting, or furthering fraud or the creation or promotion of disinformation - b. Generating, promoting, or furthering defamatory content, including the creation of - defamatory statements, images, or other content - c. Generating, promoting, or further distributing spam - d. Impersonating another individual without consent, authorization, or legal right - e. Representing that the use of Meta Llama 3 or outputs are human-generated - f. Generating or facilitating false online engagement, including fake reviews and other - means of fake online engagement - g. Fail to appropriately disclose to end users any known dangers of your AI system - - Please report any violation of this Policy, software “bug,” or other problems that could - lead to a violation of this Policy through one of the following means: - - * Reporting issues with the model: https://github.com/meta-llama/llama3 - * Reporting risky content generated by the model: developers.facebook.com/llama_output_feedback - * Reporting bugs and security concerns: facebook.com/whitehat/info - * Reporting violations of the Acceptable Use Policy or unlicensed uses of Meta Llama 3: - LlamaUseReport@meta.com - - ``` - -## Loguru - [MIT License](https://github.com/Delgan/loguru/blob/master/LICENSE) - - ``` - - MIT License - - Copyright (c) 2017 - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - - ``` - -## Mediapy - [Apache License 2.0](https://github.com/google/mediapy/blob/main/LICENSE) - - ``` - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. 
- - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. 
This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- - ``` - -## Nltk - [Apache License 2.0](https://github.com/nltk/nltk/blob/develop/LICENSE.txt) - - ``` - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. 
Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - ``` - -## PEFT - [Apache License 2.0](https://github.com/huggingface/peft/blob/main/LICENSE) - - ``` - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. 
For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. 
The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - ``` - -## Pillow - [MIT License](https://github.com/python-pillow/Pillow/blob/main/LICENSE) - - ``` - - The Python Imaging Library (PIL) is - - Copyright © 1997-2011 by Secret Labs AB - Copyright © 1995-2011 by Fredrik Lundh and contributors - - Pillow is the friendly PIL fork. It is - - Copyright © 2010 by Jeffrey A. Clark and contributors - - Like PIL, Pillow is licensed under the open source MIT-CMU License: - - By obtaining, using, and/or copying this software and/or its associated - documentation, you agree that you have read, understood, and will comply - with the following terms and conditions: - - Permission to use, copy, modify and distribute this software and its - documentation for any purpose and without fee is hereby granted, - provided that the above copyright notice appears in all copies, and that - both that copyright notice and this permission notice appear in supporting - documentation, and that the name of Secret Labs AB or the author not be - used in advertising or publicity pertaining to distribution of the software - without specific, written prior permission. - - SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS - SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. - IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL, - INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM - LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE - OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR - PERFORMANCE OF THIS SOFTWARE. - - ``` - -## PyAV - [BSD 3-Clause "New" or "Revised" License](https://github.com/PyAV-Org/PyAV/blob/main/LICENSE.txt) - - ``` - - Copyright retained by original committers. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. 
- * Neither the name of the project nor the names of its contributors may be - used to endorse or promote products derived from this software without - specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY DIRECT, - INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY - OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, - EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - ``` - -## Pytorch_Retinaface - [MIT License](https://github.com/biubug6/Pytorch_Retinaface/blob/master/LICENSE.MIT) - - ``` - MIT License - - Copyright (c) 2019 - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - ``` - -## Sentencepiece - [Apache License 2.0](https://github.com/google/sentencepiece/blob/master/LICENSE) - - ``` - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. 
- - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - ``` - -## Termcolor - [MIT License](https://github.com/termcolor/termcolor/blob/main/COPYING.txt) - - ``` - Copyright (c) 2008-2011 Volvox Development Team - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in - all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - THE SOFTWARE. - ``` - -## Transformers [Apache License 2.0](https://github.com/huggingface/transformers/blob/main/LICENSE) - - ``` - - Copyright 2018- The Hugging Face team. All rights reserved. - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." 
- - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. 
- - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. 
We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - ``` diff --git a/cosmos-transfer1/CONTRIBUTING.md b/cosmos-transfer1/CONTRIBUTING.md deleted file mode 100644 index 35d4e592856b67b5ccf4bf9253d4476635d8d797..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/CONTRIBUTING.md +++ /dev/null @@ -1,51 +0,0 @@ -# How to Contribute - -We'd love to receive your patches and contributions. Please keep your PRs as draft until such time that you would like us to review them. - -## Code Reviews - -All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult -[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. - -## Signing Your Work - -* We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. - - * Any contribution which contains commits that are not Signed-Off will not be accepted. - -* To sign off on a commit you simply use the `--signoff` (or `-s`) option when committing your changes: - ```bash - $ git commit -s -m "Add cool feature." - ``` - This will append the following to your commit message: - ``` - Signed-off-by: Your Name - ``` - -* Full text of the DCO: - - ``` - Developer Certificate of Origin - Version 1.1 - - Copyright (C) 2004, 2006 The Linux Foundation and its contributors. - 1 Letterman Drive - Suite D4700 - San Francisco, CA, 94129 - - Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. - ``` - - ``` - Developer's Certificate of Origin 1.1 - - By making a contribution to this project, I certify that: - - (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or - - (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or - - (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. - - (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. 
- ``` diff --git a/cosmos-transfer1/Dockerfile b/cosmos-transfer1/Dockerfile deleted file mode 100644 index 9996abb0940722a1052eac71e57205813ee5b476..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/Dockerfile +++ /dev/null @@ -1,47 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Use NVIDIA PyTorch container as base image -FROM nvcr.io/nvidia/tritonserver:25.04-vllm-python-py3 - -# Install basic tools -RUN apt-get update && apt-get install -y git tree ffmpeg wget -RUN rm /bin/sh && ln -s /bin/bash /bin/sh && ln -s /lib64/libcuda.so.1 /lib64/libcuda.so -RUN apt-get install -y libglib2.0-0 - -# Copy the cosmos-transfer1.yaml and requirements.txt files to the container -COPY ./cosmos-transfer1.yaml /cosmos-transfer1.yaml -COPY ./requirements_docker.txt /requirements.txt - -RUN ls -l /usr/lib/python3/dist-packages/blinker-1.7.0.dist-info && rm -rf /usr/lib/python3/dist-packages/blinker-1.7.0.dist-info -RUN echo "Installing dependencies. This will take a while..." && \ - pip install --no-cache-dir -r /requirements.txt && \ - pip install -v --upgrade --no-build-isolation --no-dependencies sam2==1.1.0 && \ - pip install transformer-engine[pytorch] && \ - pip install decord==0.6.0 && \ - git clone https://github.com/NVIDIA/apex && \ - pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./apex && \ - rm -rf apex && \ - pip install -v decord==0.6.0 && \ - echo "Environment setup complete" - -# Create Python symlink -RUN ln -s /usr/bin/python3.12 /usr/bin/python -RUN apt-get install -y libmagic1 - -RUN mkdir -p /workspace -WORKDIR /workspace - -CMD ["/bin/bash"] diff --git a/cosmos-transfer1/INSTALL.md b/cosmos-transfer1/INSTALL.md deleted file mode 100644 index 404b207500b9be7c18da78e4c774c7c84423c31b..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/INSTALL.md +++ /dev/null @@ -1,88 +0,0 @@ -## Environment setup - -Clone the `cosmos-transfer1` source code -```bash -git clone git@github.com:nvidia-cosmos/cosmos-transfer1.git -cd cosmos-transfer1 -git submodule update --init --recursive -``` - -Cosmos runs only on Linux systems. We have tested the installation with Ubuntu 24.04, 22.04, and 20.04. -Cosmos requires the Python version to be `3.12.x`. - -### Inference using conda - -Please also make sure you have `conda` installed ([instructions](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html)). - -The below commands create the `cosmos-transfer1` conda environment and install the dependencies for inference: -```bash -# Create the cosmos-transfer1 conda environment. -conda env create --file cosmos-transfer1.yaml -# Activate the cosmos-transfer1 conda environment. -conda activate cosmos-transfer1 -# Install the dependencies. 
-pip install -r requirements.txt -# Install vllm -pip install https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.5%2Bcu128torch2.7-cp38-abi3-linux_x86_64.whl -export VLLM_ATTENTION_BACKEND=FLASHINFER -pip install vllm==0.9.0 -# Install decord -pip install decord==0.6.0 -# Patch Transformer engine linking issues in conda environments. -ln -sf $CONDA_PREFIX/lib/python3.12/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/ -ln -sf $CONDA_PREFIX/lib/python3.12/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/python3.12 -# Install Transformer engine. -pip install transformer-engine[pytorch] -``` - -To test the environment setup for inference run -```bash -PYTHONPATH=$(pwd) python scripts/test_environment.py -``` - -### Inference using docker - -If you prefer to use a containerized environment, you can build and run this repo's dockerfile to get an environment with all the packages pre-installed. This environment does not use conda. So, there is no need to specify `CUDA_HOME=$CONDA_PREFIX` when invoking this repo's scripts. - -This requires docker to be already present on your system with the [Nvidia Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) installed. - -```bash -docker build -f Dockerfile . -t nvcr.io/$USER/cosmos-transfer1:latest -``` - -Note: In case you encounter permission issues while mounting local files inside the docker, you can share the folders from your current directory to all users (including docker) using this helpful alias -``` -alias share='sudo chown -R ${USER}:users $PWD && sudo chmod g+w $PWD' -``` -before running the docker. - -### Training - -The below commands creates the `cosmos-transfer` conda environment and installs the dependencies for training. This is the same as required for inference. -```bash -# Create the cosmos-transfer1 conda environment. -conda env create --file cosmos-transfer1.yaml -# Activate the cosmos-transfer1 conda environment. -conda activate cosmos-transfer1 -# Install the dependencies. -pip install -r requirements.txt -# Install vllm -pip install https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.5%2Bcu128torch2.7-cp38-abi3-linux_x86_64.whl -export VLLM_ATTENTION_BACKEND=FLASHINFER -pip install vllm==0.9.0 -# Install decord -pip install decord==0.6.0 -# Patch Transformer engine linking issues in conda environments. -ln -sf $CONDA_PREFIX/lib/python3.12/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/ -ln -sf $CONDA_PREFIX/lib/python3.12/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/python3.12 -# Install Transformer engine. -pip install transformer-engine[pytorch] -# Install Apex for full training with bfloat16. -git clone https://github.com/NVIDIA/apex -pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./apex -``` - -You can test the environment setup for post-training with -```bash -PYTHONPATH=$(pwd) python scripts/test_environment.py --training -``` diff --git a/cosmos-transfer1/LICENSE b/cosmos-transfer1/LICENSE deleted file mode 100644 index 4c9ad980682246bd6ab0d2bae82232be6dbdcbd4..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. 
- - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. 
This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
diff --git a/cosmos-transfer1/README.md b/cosmos-transfer1/README.md deleted file mode 100644 index 10885910ae824b10d34174cb4c84b08d672fd4bf..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/README.md +++ /dev/null @@ -1,102 +0,0 @@ -

- [Image: NVIDIA Cosmos Header]

- -### [Product Website](https://www.nvidia.com/en-us/ai/cosmos/) | [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-transfer1-67c9d328196453be6e568d3e) | [Paper](https://arxiv.org/abs/2503.14492) | [Paper Website](https://research.nvidia.com/labs/dir/cosmos-transfer1/) - -Cosmos-Transfer1 is a key branch of Cosmos World Foundation Models (WFMs) specialized for multimodal controllable conditional world generation or world2world transfer. The three main branches of Cosmos WFMs are [cosmos-predict](https://github.com/nvidia-cosmos/cosmos-predict1), [cosmos-transfer](https://github.com/nvidia-cosmos/cosmos-transfer1), and [cosmos-reason](https://github.com/nvidia-cosmos/cosmos-reason1). We visualize the architecture of Cosmos-Transfer1 in the following figure. - -

- [Image: Cosmos-Transfer1 Architecture Diagram]

- - -Cosmos-Transfer1 includes the following: -- **ControlNet-based single modality conditional world generation** where a user can generate visual simulation based on one of the following modalities: segmentation video, depth video, edge video, blur video, LiDAR video, or HDMap video. Cosmos-Transfer1 generates a video based on the single modality conditional input, a user text prompt, and, optionally, an input RGB video frame prompt (which could be from the last video generation result when operating in the autoregressive setting). We will use Cosmos-Transfer1-7B [Modality] to refer to the model operating in this setting. For example, Cosmos-Transfer1-7B [Depth] refers to a depth ControlNet model. -- **MultiControlNet-based multimodal conditional world generation** where a user can generate visual simulation based on any combination of segmentation video, depth video, edge video, and blur video (LiDAR video and HDMap in the AV sample) with a spatiotemporal control map to control the strength of each modality across space and time. Cosmos-Transfer1 generates a video based on the multimodal conditional inputs, a user text prompt, and, optionally, an input RGB video frame prompt (which could be from the last video generation result when operating in the autoregressive setting). This is the preferred mode of Cosmos-Transfer. We will refer to it as Cosmos-Transfer1-7B. -- **4KUpscaler** for upscaling a 720p-resolution video to a 4K-resolution video. -- **Post-training scripts** for helping Physical AI builders post-train pre-trained Cosmos-Transfer1 for their applications. -- **Pre-training scripts** for helping Physical AI builders train their own Cosmos-Transfer1 models from scratch. - -## News -- [2025/05] **Cosmos AV Single2MultiView** is available! Now you can create dynamic, multi-view clips from just one video. Try it out and tell us what you think! - - [Inference guide](examples/inference_cosmos_transfer1_7b_sample_av_single2multiview.md) - - [Build your own or PyTorch post-training](examples/training_cosmos_transfer_7B_sample_AV.md) - - - [Hugging Face model](https://huggingface.co/nvidia/Cosmos-Transfer1-7B-Sample-AV-Single2MultiView) - -- [2025/04] [Post training](README.md#post-train-pre-trained-cosmos-transfer1-models) is available! Now you can customize Transfer1 models in your own way. Please try it out and we look forward to your feedback. - -## Example Model Behavior - -[Cosmos-Transfer LiDAR + HDMap Conditional Inputs -> World](https://github.com/nvidia-cosmos/cosmos-transfer1) - - - -[Cosmos-Transfer Multimodal Conditional Inputs -> World](https://github.com/nvidia-cosmos/cosmos-transfer1) - - - -## Getting Started - -We provide a comprehensive set of examples to illustrate how to perform inference, post-training, etc., with Cosmos-Transfer1. Click a relevant example below and start your Cosmos journey. - -### Installation - -Please refer to [INSTALL.md](INSTALL.md) for general instructions on environment setup.
- -### Inference with pre-trained Cosmos-Transfer1 models - -* [Inference with pre-trained Cosmos-Transfer1-7B](/examples/inference_cosmos_transfer1_7b.md) **[with multi-GPU support]** -* [Inference with pre-trained Cosmos-Transfer1-7B-Sample-AV](/examples/inference_cosmos_transfer1_7b_sample_av.md) **[with multi-GPU support]** -* [Inference with pre-trained Cosmos-Transfer1-7B-4KUpscaler](/examples/inference_cosmos_transfer1_7b_4kupscaler.md) **[with multi-GPU support]** -* [Inference with pre-trained Cosmos-Transfer1-7B (Depth)](examples/inference_cosmos_transfer1_7b_depth.md) -* [Inference with pre-trained Cosmos-Transfer1-7B (Segmentation)](examples/inference_cosmos_transfer1_7b_seg.md) -* [Inference with pre-trained Cosmos-Transfer1-7B (Edge)](examples/inference_cosmos_transfer1_7b.md#example-1-single-control-edge) -* [Inference with pre-trained Cosmos-Transfer1-7B (Vis)](examples/inference_cosmos_transfer1_7b_vis.md) -* [Inference with pre-trained Cosmos-Transfer1pt1-7B [Keypoint]](/examples/inference_cosmos_transfer1pt1_7b_keypoint.md) -* [Inference with pre-trained Cosmos-Transfer1-7B-Sample-AV-Multiview](/examples/inference_cosmos_transfer1_7b_sample_av_single2multiview.md) - -### Post-train pre-trained Cosmos-Transfer1 models - -* [Post-train pre-trained Cosmos-Transfer1-7B [Depth | Edge | Keypoint | Segmentation | Vis]](examples/training_cosmos_transfer_7b.md) **[with multi-GPU support]** -* [Post-train pre-trained Cosmos-Transfer1-7B-Sample-AV [LiDAR|HDMap]](examples/training_cosmos_transfer_7B_sample_AV.md) **[with multi-GPU support]** -* [Post-train pre-trained Cosmos-Transfer1-7B-Sample-AV-Multiview [LiDAR|HDMap]](examples/training_cosmos_transfer_7B_sample_AV.md) **[with multi-GPU support]** - -### Build your own Cosmos-Transfer1 models from scratch - -* [Pre-train Cosmos-Transfer1-7B [Depth | Edge | Keypoint | Segmentation | Vis]](examples/training_cosmos_transfer_7b.md) **[with multi-GPU support]** -* [Pre-train Cosmos-Transfer1-7B-Sample-AV [LiDAR|HDMap]](examples/training_cosmos_transfer_7B_sample_AV.md) **[with multi-GPU support]** -* [Pre-train Cosmos-Transfer1-7B-Sample-AV-Multiview [LiDAR|HDMap]](examples/training_cosmos_transfer_7B_sample_AV.md) **[with multi-GPU support]** - -### Workflow - -* [Robotics Augmentation Workflow](/cosmos_transfer1/auxiliary/robot_augmentation/README.md): Scene augmentation for robotic manipulation, mapping one robotics synthetic example to multiple realistic examples. - - - - -## Cosmos-Transfer1 Models - -* [Cosmos-Transfer1-7B](https://huggingface.co/nvidia/Cosmos-Transfer1-7B): multimodal controllable conditional world generation with adaptive spatiotemporal control map. The supported modalities include segmentation, depth, Canny edge, and blur visual. -* [Cosmos-Transfer1-7B [Depth | Edge | Keypoint | Segmentation | Vis]](https://huggingface.co/nvidia/Cosmos-Transfer1-7B): single modality controllable conditional world generation. This refers to Cosmos-Transfer1-7B operating on the single-modality case, where it reduces to a ControlNet. -* [Cosmos-Transfer1-7B-Sample-AV](https://huggingface.co/nvidia/Cosmos-Transfer1-7B-Sample-AV): multimodal controllable conditional world generation with adaptive spatiotemporal control map specialized for autonomous vehicle applications. The supported modalities include LiDAR and HDMap. -* [Cosmos-Transfer1-7B [LiDAR | HDMap]](https://huggingface.co/nvidia/Cosmos-Transfer1-7B-Sample-AV): single modality controllable conditional world generation for autonomous vehicle applications.
This refers to Cosmos-Transfer1-7B-Sample-AV operating on the single-modality case, where it reduces to a ControlNet. -* [Cosmos-Transfer1-7B-4KUpscaler](https://huggingface.co/nvidia/Cosmos-Transfer1-7B-4KUpscaler): 4K upscaler to super-resolve 720p videos to 4K videos. - - -## License and Contact - -This project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use. - -This model includes safety and content moderation features powered by Llama Guard 3. Llama Guard 3 is used solely as a content input filter and is subject to its own license. - -NVIDIA Cosmos source code is released under the [Apache 2 License](https://www.apache.org/licenses/LICENSE-2.0). - -NVIDIA Cosmos models are released under the [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). For a custom license, please contact [cosmos-license@nvidia.com](mailto:cosmos-license@nvidia.com). diff --git a/cosmos-transfer1/checkpoints/README.md b/cosmos-transfer1/checkpoints/README.md deleted file mode 100644 index 80ddafebc8786e5ed5f6cc13a8adf4836f3a63a5..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/checkpoints/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Checkpoint directory - -Follow our instructions for downloading checkpoints in [Cosmos Diffusion Inference](../cosmos_transfer1/diffusion/README.md#download-checkpoints). Cosmos checkpoints will be downloaded to this directory. diff --git a/cosmos-transfer1/cosmos-transfer1.yaml b/cosmos-transfer1/cosmos-transfer1.yaml deleted file mode 100644 index 7ecc96288f34587dded961dc0134b0dcab16d9a8..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos-transfer1.yaml +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: cosmos-transfer1 -channels: - - conda-forge -dependencies: - - python=3.12 - - pip=25.0 - - cmake - - ninja - - libgl - - ffmpeg - - gcc=12.4.0 - - gxx=12.4.0 - - cuda=12.4 - - cuda-nvcc=12.4 - - cuda-toolkit=12.4 diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/depth_anything/inference/__init__.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/depth_anything/inference/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/depth_anything/inference/depth_anything_pipeline.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/depth_anything/inference/depth_anything_pipeline.py deleted file mode 100644 index 86fd3f3a2747872b0f08d2c3dff6c2e4d01c51ec..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/depth_anything/inference/depth_anything_pipeline.py +++ /dev/null @@ -1,55 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse - -from PIL import Image - -from cosmos_transfer1.auxiliary.depth_anything.model.depth_anything import DepthAnythingModel - - -def parse_args(): - parser = argparse.ArgumentParser(description="Depth Estimation using Depth Anything V2") - parser.add_argument("--input", type=str, required=True, help="Path to input image or video file") - parser.add_argument("--output", type=str, required=True, help="Path to save the output image or video") - parser.add_argument( - "--mode", - type=str, - choices=["image", "video"], - default="image", - help="Processing mode: 'image' for a single image, 'video' for a video file", - ) - return parser.parse_args() - - -def main(): - args = parse_args() - model = DepthAnythingModel() - - if args.mode == "image": - # Load the input image and predict its depth - image = Image.open(args.input).convert("RGB") - depth_image = model.predict_depth(image) - depth_image.save(args.output) - print(f"Depth image saved to {args.output}") - elif args.mode == "video": - # Process the video and save the output - out_path = model.predict_depth_video(args.input, args.output) - if out_path: - print(f"Depth video saved to {out_path}") - - -if __name__ == "__main__": - main() diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/depth_anything/model/__init__.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/depth_anything/model/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/depth_anything/model/depth_anything.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/depth_anything/model/depth_anything.py deleted file mode 100644 index d7af7bf06fe99440985f3303b059e54181e1cb0d..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/depth_anything/model/depth_anything.py +++ /dev/null @@ -1,151 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
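For orientation, here is a minimal usage sketch of the `DepthAnythingModel` class defined below in this file. It is illustrative only: the file names are hypothetical, and it assumes the Depth Anything checkpoint referenced by `DEPTH_ANYTHING_MODEL_CHECKPOINT` has already been downloaded.

```python
from PIL import Image

from cosmos_transfer1.auxiliary.depth_anything.model.depth_anything import DepthAnythingModel

model = DepthAnythingModel()

# Single image -> 16-bit depth map (PIL "I;16" image), via predict_depth().
depth = model.predict_depth(Image.open("frame_0000.png").convert("RGB"))  # hypothetical input
depth.save("frame_0000_depth.png")

# Whole video -> normalized uint8 depth video, via the class's __call__.
model("input.mp4", "outputs/depth.mp4")  # hypothetical paths
```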
- -import os - -import cv2 -import imageio -import numpy as np -import torch -from PIL import Image -from transformers import AutoImageProcessor, AutoModelForDepthEstimation - -from cosmos_transfer1.checkpoints import DEPTH_ANYTHING_MODEL_CHECKPOINT -from cosmos_transfer1.utils import log - - -class DepthAnythingModel: - def __init__(self): - """ - Initialize the Depth Anything model and its image processor. - """ - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # Load image processor and model with half precision - print(f"Loading Depth Anything model - {DEPTH_ANYTHING_MODEL_CHECKPOINT}...") - self.image_processor = AutoImageProcessor.from_pretrained( - DEPTH_ANYTHING_MODEL_CHECKPOINT, - torch_dtype=torch.float16, - trust_remote_code=True, - ) - self.model = AutoModelForDepthEstimation.from_pretrained( - DEPTH_ANYTHING_MODEL_CHECKPOINT, - torch_dtype=torch.float16, - trust_remote_code=True, - ).to(self.device) - - def predict_depth(self, image: Image.Image) -> Image.Image: - """ - Process a single PIL image and return a depth map as a uint16 PIL Image. - """ - # Prepare inputs for the model - inputs = self.image_processor(images=image, return_tensors="pt") - # Move all tensors to the proper device with half precision - inputs = {k: v.to(self.device, dtype=torch.float16) for k, v in inputs.items()} - - with torch.no_grad(): - outputs = self.model(**inputs) - predicted_depth = outputs.predicted_depth - - # Interpolate the predicted depth to the original image size - prediction = torch.nn.functional.interpolate( - predicted_depth.unsqueeze(1), - size=image.size[::-1], # PIL image size is (width, height), interpolate expects (height, width) - mode="bicubic", - align_corners=False, - ) - - # Convert the output tensor to a numpy array and save as a depth image - output = prediction.squeeze().cpu().numpy() - depth_image = DepthAnythingModel.save_depth(output) - return depth_image - - def __call__(self, input_video: str, output_video: str = "depth.mp4") -> str: - """ - Process a video file frame-by-frame to produce a depth-estimated video. - The output video is saved as an MP4 file. 
- """ - - log.info(f"Processing video: {input_video} to generate depth video: {output_video}") - assert os.path.exists(input_video) - - cap = cv2.VideoCapture(input_video) - if not cap.isOpened(): - print("Error: Cannot open video file.") - return - - # Retrieve video properties - frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - fps = cap.get(cv2.CAP_PROP_FPS) - - depths = [] - while True: - ret, frame = cap.read() - if not ret: - break - - # Convert frame from BGR to RGB and then to PIL Image - image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) - inputs = self.image_processor(images=image, return_tensors="pt") - inputs = {k: v.to(self.device, dtype=torch.float16) for k, v in inputs.items()} - - with torch.no_grad(): - outputs = self.model(**inputs) - predicted_depth = outputs.predicted_depth - - # For video processing, take the first output and interpolate to original size - prediction = torch.nn.functional.interpolate( - predicted_depth[0].unsqueeze(0).unsqueeze(0), - size=(frame_height, frame_width), - mode="bicubic", - align_corners=False, - ) - depth = prediction.squeeze().cpu().numpy() - depths += [depth] - cap.release() - - depths = np.stack(depths) - depths_normed = (depths - depths.min()) / (depths.max() - depths.min() + 1e-8) * 255.0 - depths_normed = depths_normed.astype(np.uint8) - - os.makedirs(os.path.dirname(output_video), exist_ok=True) - self.write_video(depths_normed, output_video, fps=fps) - return output_video - - @staticmethod - def save_depth(output: np.ndarray) -> Image.Image: - """ - Convert the raw depth output (float values) into a uint16 PIL Image. - """ - depth_min = output.min() - depth_max = output.max() - max_val = (2**16) - 1 # Maximum value for uint16 - - if depth_max - depth_min > np.finfo("float").eps: - out_array = max_val * (output - depth_min) / (depth_max - depth_min) - else: - out_array = np.zeros_like(output) - - formatted = out_array.astype("uint16") - depth_image = Image.fromarray(formatted, mode="I;16") - return depth_image - - @staticmethod - def write_video(frames, output_path, fps=30): - with imageio.get_writer(output_path, fps=fps, macro_block_size=8) as writer: - for frame in frames: - if len(frame.shape) == 2: # single channel - frame = frame[:, :, None].repeat(3, axis=2) - writer.append_data(frame) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/README.md b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/README.md deleted file mode 100644 index b76ea39be0d012157c099704bb1e0d5723ffbf1f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/README.md +++ /dev/null @@ -1,17 +0,0 @@ -# Cosmos Guardrail - -This page outlines a set of tools to ensure content safety in Cosmos. For implementation details, please consult the [Cosmos paper](https://research.nvidia.com/publication/2025-01_cosmos-world-foundation-model-platform-physical-ai). - -## Overview - -Our guardrail system consists of two stages: pre-Guard and post-Guard. - -Cosmos pre-Guard models are applied to text input, including input prompts and upsampled prompts. - -* Blocklist: a keyword list checker for detecting harmful keywords -* Llama Guard 3: an LLM-based approach for blocking harmful prompts - -Cosmos post-Guard models are applied to video frames generated by Cosmos models. 
- -* Video Content Safety Filter: a classifier trained to distinguish between safe and unsafe video frames -* Face Blur Filter: a face detection and blurring module diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/__init__.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/__init__.py deleted file mode 100644 index 3159bfe65645499015bd92609b99d476d69544e9..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/aegis/__init__.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/aegis/__init__.py deleted file mode 100644 index 3159bfe65645499015bd92609b99d476d69544e9..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/aegis/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/aegis/aegis.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/aegis/aegis.py deleted file mode 100644 index 03f81880594a610c15bda1b82361c975463c6022..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/aegis/aegis.py +++ /dev/null @@ -1,135 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import os - -import torch -from peft import PeftModel -from transformers import AutoModelForCausalLM, AutoTokenizer - -from cosmos_transfer1.auxiliary.guardrail.aegis.categories import UNSAFE_CATEGORIES -from cosmos_transfer1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner -from cosmos_transfer1.checkpoints import AEGIS_MODEL_CHECKPOINT, LLAMA_GUARD_MODEL_CHECKPOINT -from cosmos_transfer1.utils import log, misc - -SAFE = misc.Color.green("SAFE") -UNSAFE = misc.Color.red("UNSAFE") - - -class Aegis(ContentSafetyGuardrail): - def __init__( - self, - checkpoint_dir: str, - device="cuda" if torch.cuda.is_available() else "cpu", - ) -> None: - self.checkpoint_dir = checkpoint_dir - self.device = device - self.dtype = torch.bfloat16 - base_model_id = "meta-llama/LlamaGuard-7b" - aegis_adapter = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0" - - base_model_dir = os.path.join(self.checkpoint_dir, base_model_id) - aegis_adapter_dir = os.path.join(self.checkpoint_dir, aegis_adapter) - - base_model = AutoModelForCausalLM.from_pretrained(base_model_id, cache_dir=base_model_dir) - self.tokenizer = AutoTokenizer.from_pretrained(base_model_id, cache_dir=base_model_dir) - self.model = PeftModel.from_pretrained(base_model, aegis_adapter, cache_dir=aegis_adapter_dir) - - self.model.to(self.device, dtype=self.dtype).eval() - - def get_moderation_prompt(self, user_prompt: str) -> str: - """Create the moderation prompt for the Aegis model.""" - full_prompt = f"""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories. - - - {"".join([f"{category['instruction']}" for category in UNSAFE_CATEGORIES])} - - - - - """ - full_prompt += f"User: {user_prompt}\n\n" - full_prompt += """ - - Provide your safety assessment for User in the above conversation: - - First line must read 'safe' or 'unsafe'. - - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""" - - return full_prompt - - def get_aegis_block_message(self, moderation_output: str) -> str: - """Extract the blocked category and reason from the Aegis model output.""" - block_msg = "Prompt blocked by Aegis." 
- try: - # Extract blocked category and reason - if len(moderation_output.split("\n")) >= 2: - moderation_output_list = moderation_output.split("\n") - blocked_category = int(moderation_output_list[1].lstrip("O")) - 1 - if 0 <= blocked_category < len(UNSAFE_CATEGORIES): - blocked_reason = UNSAFE_CATEGORIES[blocked_category]["blocked_reason"] - blocked_category_name = UNSAFE_CATEGORIES[blocked_category]["category"] - block_msg = f"{blocked_category_name}: {blocked_reason}" - except Exception as e: - log.warning(f"Unable to extract blocked category and reason from Aegis output: {e}") - return block_msg - - def filter_aegis_output(self, prompt: str) -> tuple[bool, str]: - """Filter the Aegis model output and return the safety status and message.""" - full_prompt = self.get_moderation_prompt(prompt) - inputs = self.tokenizer([full_prompt], add_special_tokens=False, return_tensors="pt").to(self.device) - output = self.model.generate(**inputs, max_new_tokens=100, pad_token_id=self.tokenizer.eos_token_id) - prompt_len = inputs["input_ids"].shape[-1] - moderation_output = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True) - - if "unsafe" in moderation_output.lower(): - block_msg = self.get_aegis_block_message(moderation_output) - return False, block_msg - else: - return True, "" - - def is_safe(self, prompt: str) -> tuple[bool, str]: - """Check if the input prompt is safe according to the Aegis model.""" - try: - return self.filter_aegis_output(prompt) - except Exception as e: - log.error(f"Unexpected error occurred when running Aegis guardrail: {e}") - return True, "Unexpected error occurred when running Aegis guardrail." - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--prompt", type=str, required=True, help="Input prompt") - parser.add_argument( - "--checkpoint_dir", - type=str, - help="Path to the Aegis checkpoint folder", - default=DEFAULT_CHECKPOINT_DIR, - ) - return parser.parse_args() - - -def main(args): - aegis = Aegis(checkpoint_dir=args.checkpoint_dir) - runner = GuardrailRunner(safety_models=[aegis]) - with misc.timer("aegis safety check"): - safety, message = runner.run_safety_check(args.prompt) - log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}") - log.info(f"Message: {message}") if not safety else None - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/aegis/categories.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/aegis/categories.py deleted file mode 100644 index f9236da33c8d2669b701cc21709925bacbbe6907..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/aegis/categories.py +++ /dev/null @@ -1,192 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
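To make the category-ID parsing in `get_aegis_block_message` above concrete, here is a small worked sketch. The moderation output string is hypothetical; the IDs are 1-based and prefixed with `O`, matching the category definitions that follow in this file.

```python
# Hypothetical Aegis / Llama Guard style moderation output: the verdict on the
# first line, violated category IDs on the second.
moderation_output = "unsafe\nO3"

lines = moderation_output.split("\n")
blocked_category = int(lines[1].lstrip("O")) - 1  # "O3" -> index 2

# With UNSAFE_CATEGORIES as defined below, index 2 is criminal_planning, so the
# block message is built from that entry's category name and blocked_reason.
```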
- -VIOLENCE = { - "category": "violence", - "blocked_reason": """This prompt is blocked due to the presence of violent content.""", - "instruction": """O1: Violence. -Should not -- Help people plan or engage in violence. -- Encourage people to engage in violence. -- Contain unpleasant vivid violent imagery or violent description that arouses discomfort in the reader. -Can -- Provide information on violence. -- Discuss the topics of violence, and can discuss historical events involving violence. -""", -} - -SEXUAL = { - "category": "sexual", - "blocked_reason": """This prompt is blocked due to the presence of sexual content.""", - "instruction": """O2: Sexual. -Should not -- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities. -- Engage or encourage sexual activity that is considered illegal. -Can -- Discuss topics related to sexuality and sexual education. -- Provide general information regarding sexual activities. -- Provide guidance on topics related to sexual health. -""", -} - -CRIMINAL_PLANNING = { - "category": "criminal_planning", - "blocked_reason": """This prompt is blocked due to the presence of criminal planning/confessions content.""", - "instruction": """O3: Criminal Planning/Confessions. -Should not -- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them. -- Encourage people to engage in criminal activities or condone criminal behavior. -- Have a confession of crime that was committed or will be committed in future. -Can -- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes. -- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems. -""", -} - -WEAPONS = { - "category": "weapons", - "blocked_reason": """This prompt is blocked due to the presence of weapons content.""", - "instruction": """O4: Guns and Illegal Weapons. -Should not -- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them. -- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior. -- Assist or encourage people to execute or own other illegal weapons for example biological agents, or chemical weapons, including instructions on how to create or use them. -Can -- Discuss firearms and the arguments for and against firearm ownership. -""", -} - -SUBSTANCE_ABUSE = { - "category": "substance_abuse", - "blocked_reason": """This prompt is blocked due to the presence of substance abuse content.""", - "instruction": """O5: Controlled/Regulated Substances. -Should not -- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis. -- Assist or encourage people to create such substances when it is illegal to do so. -Can -- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are). 
-- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation. -""", -} - -SUICIDE = { - "category": "suicide", - "blocked_reason": "This prompt is blocked due to the presence of suicide and self harm content.", - "instruction": """O6: Suicide and Self Harm. -Should not -- Encourage people to harm themselves. -- Romanticize or condone self-harm or suicide. -- Provide information regarding the methods of suicide or self-harm. -- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves. -- Contain confessions or intent of self harm -Can -- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions). -Should -- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves. -""", -} - -CSAM = { - "category": "csam", - "blocked_reason": "This prompt contains material related to Child Sexual Abuse", - "instruction": """O7: Sexual (minor). -Should not -- Assist, engage or encourage intention to engage in sexual activity with an underage/minor. -- Discuss or plan sexual acts, sharing of media of sexual nature, personal information or intention to meet in person for sexual activity with a minor. -- Should not include discussion where private information is shared with underage minors for purposes of acting on them and taking actions that are sexual in nature including meeting or sharing media. -- Should not discuss child abuse or engage, encourage, assist in sexual grooming/trafficking of children. -Can -- Provide information about sex education for use in schools. -""", -} - -HATRED = { - "category": "hatred", - "blocked_reason": "This prompt is blocked due to the presence of hate speech content.", - "instruction": """O8: Hate/Identity Hate. -Should not -- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their -race, color, ethnicity, religion, national origin, sexual orientation, gender, gender identity, or disability or caste. -- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics. -- Should not discuss prejudice towards a group classified as protected according to the law of Germany, UK, USA or any other country. -- Should not contain slurs. -- Should not contain text that is distasteful, offensive, or hateful that causes others to feel grief, anger, or fear. -Can -- Provide information on discrimination. -- Discuss the topics of discrimination, and can discuss historical events involving discrimination. -- Discuss neutral statements referring to group identity and contextualized hate speech (that talks about hate but does not incite hate). -""", -} - -HARASSMENT = { - "category": "harassment", - "blocked_reason": "This prompt is blocked due to the presence of harassment content.", - "instruction": """O10: Harassment. -Should not -- Have content that may be used to torment or annoy individuals in real life, or make harassment more likely to occur. 
-- Engage in harassment, misconduct and unlawful discrimination -- Single out an individual or marginalization and/or retaliation based on protected characteristics like Race, Color, Gender, Sex, -Sexual orientation, Gender identity, and gender expression, National origin, Ethnicity, Disability (including being regarded as disabled), -Religion, Age (40+), Pregnancy (including pregnancy, childbirth or related medical conditions), Genetic information, Military or veteran status, Citizenship status, Political activity, or affiliation, -- Taking or requesting statutorily protected leave Body characteristics. -- Promise to give a benefit, or a threat to retaliate or take an adverse action based on a response to harassment and misconduct. -- Bullying. -- Being derogatory and humiliating toward an individual but not necessarily protected characteristics under law. -- Making rude or insulting comments, demeaning, and objectifying terms toward an individual. -Can -- Discuss negative impacts of harassment and bullying based on protected characteristics like race, color, gender, sex, nationality, within a contextualized setting. -- Provide information on protected characteristics or refers, cites material that describes what protected characteristics are. -- Provide information or refer or cite material that informs what harassment is. -""", -} - -THREAT = { - "category": "threat", - "blocked_reason": "This prompt is blocked due to the presence of threat content.", - "instruction": """O11: Threat. -Should not -- Discuss a communication of intent to inflict harm or loss on another person. -- Should not contain discussion of intention of intimidation which is a tactic used between conflicting parties to make the other timid or psychologically insecure for coercion or control. -- Should not contain intention of knowingly putting another person in fear of bodily injury. -- Should not assist or plan injury, physical or mental damage, act or instance of injury, or material or any kind of loss to a person. -- Should not contain threats that are forbidden by law as intent to obtain a monetary advantage or to compel a person to act against their will. -Can -- Provide information about threat is and the meaning of threat. -""", -} - -PROFANITY = { - "category": "profanity", - "blocked_reason": "This prompt is blocked due to the presence of profanity content.", - "instruction": """O12: Profanity. -Should not -- Contain swear words, curse words, or other obscene or profane language -- This includes offensive words used without any intention to act on them or without any explicit targeted recipient -""", -} - -UNSAFE_CATEGORIES = [ - VIOLENCE, - SEXUAL, - CRIMINAL_PLANNING, - WEAPONS, - SUBSTANCE_ABUSE, - SUICIDE, - CSAM, - HATRED, - HARASSMENT, - THREAT, - PROFANITY, -] diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/blocklist/__init__.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/blocklist/__init__.py deleted file mode 100644 index 3159bfe65645499015bd92609b99d476d69544e9..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/blocklist/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/blocklist/blocklist.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/blocklist/blocklist.py deleted file mode 100644 index 875237aeb0f00f3ce5a7fc4ca620782b3c365650..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/blocklist/blocklist.py +++ /dev/null @@ -1,216 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import os -import re -import string -from difflib import SequenceMatcher - -import nltk -from better_profanity import profanity - -from cosmos_transfer1.auxiliary.guardrail.blocklist.utils import read_keyword_list_from_dir, to_ascii -from cosmos_transfer1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner -from cosmos_transfer1.utils import log, misc - -CENSOR = misc.Color.red("*") - - -class Blocklist(ContentSafetyGuardrail): - def __init__( - self, - checkpoint_dir: str, - guardrail_partial_match_min_chars: int = 6, - guardrail_partial_match_letter_count: float = 0.4, - ) -> None: - self.checkpoint_dir = os.path.join(checkpoint_dir, "nvidia/Cosmos-Guardrail1/blocklist") - nltk.data.path.append(os.path.join(self.checkpoint_dir, "nltk_data")) - self.lemmatizer = nltk.WordNetLemmatizer() - self.profanity = profanity - self.guardrail_partial_match_min_chars = guardrail_partial_match_min_chars - self.guardrail_partial_match_letter_count = guardrail_partial_match_letter_count - - # Load blocklist and whitelist keywords - self.blocklist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "custom")) - self.whitelist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "whitelist")) - self.exact_match_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "exact_match")) - - self.profanity.load_censor_words(custom_words=self.blocklist_words, whitelist_words=self.whitelist_words) - log.debug(f"Loaded {len(self.blocklist_words)} words/phrases from blocklist") - log.debug(f"Whitelisted {len(self.whitelist_words)} words/phrases from whitelist") - log.debug(f"Loaded {len(self.exact_match_words)} exact match words/phrases from blocklist") - - def uncensor_whitelist(self, input_prompt: str, censored_prompt: str) -> str: - """Explicitly uncensor words that are in the whitelist.""" - input_words = input_prompt.split() - censored_words = censored_prompt.split() - whitelist_words = set(self.whitelist_words) - for i, token in 
enumerate(input_words): - if token.strip(string.punctuation).lower() in whitelist_words: - censored_words[i] = token - censored_prompt = " ".join(censored_words) - return censored_prompt - - def censor_prompt(self, input_prompt: str) -> tuple[bool, str]: - """Censor the prompt using the blocklist with better-profanity fuzzy matching. - - Args: - input_prompt: input prompt to censor - - Returns: - bool: True if the prompt is blocked, False otherwise - str: A message indicating why the prompt was blocked - """ - censored_prompt = self.profanity.censor(input_prompt, censor_char=CENSOR) - # Uncensor whitelisted words that were censored from blocklist fuzzy matching - censored_prompt = self.uncensor_whitelist(input_prompt, censored_prompt) - if CENSOR in censored_prompt: - return True, f"Prompt blocked by censorship: Censored Prompt: {censored_prompt}" - return False, "" - - @staticmethod - def check_partial_match( - normalized_prompt: str, normalized_word: str, guardrail_partial_match_letter_count: float - ) -> tuple[bool, str]: - """ - Check robustly if normalized word and the matching target have a difference of up to guardrail_partial_match_letter_count characters. - - Args: - normalized_prompt: a string with many words - normalized_word: a string with one or multiple words, its length is smaller than normalized_prompt - guardrail_partial_match_letter_count: maximum allowed difference in characters (float to allow partial characters) - - Returns: - bool: True if a match is found, False otherwise - str: A message indicating why the prompt was blocked - """ - prompt_words = normalized_prompt.split() - word_length = len(normalized_word.split()) - max_similarity_ratio = (len(normalized_word) - float(guardrail_partial_match_letter_count)) / float( - len(normalized_word) - ) - - for i in range(len(prompt_words) - word_length + 1): - # Extract a substring from the prompt with the same number of words as the normalized_word - substring = " ".join(prompt_words[i : i + word_length]) - similarity_ratio = SequenceMatcher(None, substring, normalized_word).ratio() - if similarity_ratio >= max_similarity_ratio: - return ( - True, - f"Prompt blocked by partial match blocklist: Prompt: {normalized_prompt}, Partial Match Word: {normalized_word}", - ) - - return False, "" - - @staticmethod - def check_against_whole_word_blocklist( - prompt: str, - blocklist: list[str], - guardrail_partial_match_min_chars: int = 6, - guardrail_partial_match_letter_count: float = 0.4, - ) -> bool: - """ - Check if the prompt contains any whole words from the blocklist. - The match is case insensitive and robust to multiple spaces between words. 
- - Args: - prompt: input prompt to check - blocklist: list of words to check against - guardrail_partial_match_min_chars: minimum number of characters in a word to check for partial match - guardrail_partial_match_letter_count: maximum allowed difference in characters for partial match - - Returns: - bool: True if a match is found, False otherwise - str: A message indicating why the prompt was blocked - """ - # Normalize spaces and convert to lowercase - normalized_prompt = re.sub(r"\s+", " ", prompt).strip().lower() - - for word in blocklist: - # Normalize spaces and convert to lowercase for each blocklist word - normalized_word = re.sub(r"\s+", " ", word).strip().lower() - - # Use word boundaries to ensure whole word match - if re.search(r"\b" + re.escape(normalized_word) + r"\b", normalized_prompt): - return True, f"Prompt blocked by exact match blocklist: Prompt: {prompt}, Exact Match Word: {word}" - - # Check for partial match if the word is long enough - if len(normalized_word) >= guardrail_partial_match_min_chars: - match, message = Blocklist.check_partial_match( - normalized_prompt, normalized_word, guardrail_partial_match_letter_count - ) - if match: - return True, message - - return False, "" - - def is_safe(self, input_prompt: str = "") -> tuple[bool, str]: - """Check if the input prompt is safe using the blocklist.""" - # Check if the input is empty - if not input_prompt: - return False, "Input is empty" - input_prompt = to_ascii(input_prompt) - - # Check full sentence for censored words - censored, message = self.censor_prompt(input_prompt) - if censored: - return False, message - - # Check lemmatized words for censored words - tokens = nltk.word_tokenize(input_prompt) - lemmas = [self.lemmatizer.lemmatize(token) for token in tokens] - lemmatized_prompt = " ".join(lemmas) - censored, message = self.censor_prompt(lemmatized_prompt) - if censored: - return False, message - - # Check for exact match blocklist words - censored, message = self.check_against_whole_word_blocklist( - input_prompt, - self.exact_match_words, - self.guardrail_partial_match_min_chars, - self.guardrail_partial_match_letter_count, - ) - if censored: - return False, message - - # If all these checks pass, the input is safe - return True, "Input is safe" - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--prompt", type=str, required=True, help="Input prompt") - parser.add_argument( - "--checkpoint_dir", - type=str, - help="Path to the Blocklist checkpoint folder", - ) - return parser.parse_args() - - -def main(args): - blocklist = Blocklist(checkpoint_dir=args.checkpoint_dir) - runner = GuardrailRunner(safety_models=[blocklist]) - with misc.timer("blocklist safety check"): - safety, message = runner.run_safety_check(args.prompt) - log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}") - log.info(f"Message: {message}") if not safety else None - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/blocklist/utils.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/blocklist/utils.py deleted file mode 100644 index a8991621958bb86dff5a1d1f13bd6d7e7276e7fa..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/blocklist/utils.py +++ /dev/null @@ -1,45 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import re - -from cosmos_transfer1.utils import log - - -def read_keyword_list_from_dir(folder_path: str) -> list[str]: - """Read keyword list from all files in a folder.""" - output_list = [] - file_list = [] - # Get list of files in the folder - for file in os.listdir(folder_path): - if os.path.isfile(os.path.join(folder_path, file)): - file_list.append(file) - - # Process each file - for file in file_list: - file_path = os.path.join(folder_path, file) - try: - with open(file_path, "r") as f: - output_list.extend([line.strip() for line in f.readlines()]) - except Exception as e: - log.error(f"Error reading file {file}: {str(e)}") - - return output_list - - -def to_ascii(prompt: str) -> str: - """Convert prompt to ASCII.""" - return re.sub(r"[^\x00-\x7F]+", " ", prompt) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/common/__init__.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/common/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/common/core.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/common/core.py deleted file mode 100644 index 9f333e2abbb5326452d9334dbedcfb42113a277e..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/common/core.py +++ /dev/null @@ -1,71 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Any, Tuple - -import numpy as np - -from cosmos_transfer1.utils import log - - -class ContentSafetyGuardrail: - def is_safe(self, **kwargs) -> Tuple[bool, str]: - raise NotImplementedError("Child classes must implement the is_safe method") - - -class PostprocessingGuardrail: - def postprocess(self, frames: np.ndarray) -> np.ndarray: - raise NotImplementedError("Child classes must implement the postprocess method") - - -class GuardrailRunner: - def __init__( - self, - safety_models: list[ContentSafetyGuardrail] | None = None, - generic_block_msg: str = "", - generic_safe_msg: str = "", - postprocessors: list[PostprocessingGuardrail] | None = None, - ): - self.safety_models = safety_models - self.generic_block_msg = generic_block_msg - self.generic_safe_msg = generic_safe_msg if generic_safe_msg else "Prompt is safe" - self.postprocessors = postprocessors - - def run_safety_check(self, input: Any) -> Tuple[bool, str]: - """Run the safety check on the input.""" - if not self.safety_models: - log.warning("No safety models found, returning safe") - return True, self.generic_safe_msg - - for guardrail in self.safety_models: - guardrail_name = str(guardrail.__class__.__name__).upper() - log.debug(f"Running guardrail: {guardrail_name}") - safe, message = guardrail.is_safe(input) - if not safe: - reasoning = self.generic_block_msg if self.generic_block_msg else f"{guardrail_name}: {message}" - return False, reasoning - return True, self.generic_safe_msg - - def postprocess(self, frames: np.ndarray) -> np.ndarray: - """Run the postprocessing on the video frames.""" - if not self.postprocessors: - log.warning("No postprocessors found, returning original frames") - return frames - - for guardrail in self.postprocessors: - guardrail_name = str(guardrail.__class__.__name__).upper() - log.debug(f"Running guardrail: {guardrail_name}") - frames = guardrail.postprocess(frames) - return frames diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/common/io_utils.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/common/io_utils.py deleted file mode 100644 index 9f79410b0af6450a5016712e97197c38a4ed8873..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/common/io_utils.py +++ /dev/null @@ -1,78 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import glob -from dataclasses import dataclass - -import imageio -import numpy as np - -from cosmos_transfer1.utils import log - - -@dataclass -class VideoData: - frames: np.ndarray # Shape: [B, H, W, C] - fps: int - duration: int # in seconds - - -def get_video_filepaths(input_dir: str) -> list[str]: - """Get a list of filepaths for all videos in the input directory.""" - paths = glob.glob(f"{input_dir}/**/*.mp4", recursive=True) - paths += glob.glob(f"{input_dir}/**/*.avi", recursive=True) - paths += glob.glob(f"{input_dir}/**/*.mov", recursive=True) - paths = sorted(paths) - log.debug(f"Found {len(paths)} videos") - return paths - - -def read_video(filepath: str) -> VideoData: - """Read a video file and extract its frames and metadata.""" - try: - reader = imageio.get_reader(filepath, "ffmpeg") - except Exception as e: - raise ValueError(f"Failed to read video file: {filepath}") from e - - # Extract metadata from the video file - try: - metadata = reader.get_meta_data() - fps = metadata.get("fps") - duration = metadata.get("duration") - except Exception as e: - reader.close() - raise ValueError(f"Failed to extract metadata from video file: {filepath}") from e - - # Extract frames from the video file - try: - frames = np.array([frame for frame in reader]) - except Exception as e: - raise ValueError(f"Failed to extract frames from video file: {filepath}") from e - finally: - reader.close() - - return VideoData(frames=frames, fps=fps, duration=duration) - - -def save_video(filepath: str, frames: np.ndarray, fps: int) -> None: - """Save a video file from a sequence of frames.""" - try: - writer = imageio.get_writer(filepath, fps=fps, macro_block_size=1) - for frame in frames: - writer.append_data(frame) - except Exception as e: - raise ValueError(f"Failed to save video file to {filepath}") from e - finally: - writer.close() diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/common/presets.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/common/presets.py deleted file mode 100644 index ac6c711e8d39047c527d46a219520a30a237db4f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/common/presets.py +++ /dev/null @@ -1,75 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os - -import numpy as np - -from cosmos_transfer1.auxiliary.guardrail.blocklist.blocklist import Blocklist -from cosmos_transfer1.auxiliary.guardrail.common.core import GuardrailRunner -from cosmos_transfer1.auxiliary.guardrail.face_blur_filter.face_blur_filter import RetinaFaceFilter -from cosmos_transfer1.auxiliary.guardrail.llamaGuard3.llamaGuard3 import LlamaGuard3 -from cosmos_transfer1.auxiliary.guardrail.video_content_safety_filter.video_content_safety_filter import ( - VideoContentSafetyFilter, -) -from cosmos_transfer1.utils import log - - -def create_text_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner: - """Create the text guardrail runner.""" - return GuardrailRunner(safety_models=[Blocklist(checkpoint_dir), LlamaGuard3(checkpoint_dir)]) - - -def create_video_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner: - """Create the video guardrail runner.""" - return GuardrailRunner( - safety_models=[VideoContentSafetyFilter(checkpoint_dir)], - postprocessors=[RetinaFaceFilter(checkpoint_dir)], - ) - - -def run_text_guardrail(prompt: str, guardrail_runner: GuardrailRunner) -> bool: - """Run the text guardrail on the prompt, checking for content safety. - - Args: - prompt: The text prompt. - guardrail_runner: The text guardrail runner. - - Returns: - bool: Whether the prompt is safe. - """ - is_safe, message = guardrail_runner.run_safety_check(prompt) - if not is_safe: - log.critical(f"GUARDRAIL BLOCKED: {message}") - return is_safe - - -def run_video_guardrail(frames: np.ndarray, guardrail_runner: GuardrailRunner) -> np.ndarray | None: - """Run the video guardrail on the frames, checking for content safety and applying face blur. - - Args: - frames: The frames of the generated video. - guardrail_runner: The video guardrail runner. - - Returns: - The processed frames if safe, otherwise None. - """ - is_safe, message = guardrail_runner.run_safety_check(frames) - if not is_safe: - log.critical(f"GUARDRAIL BLOCKED: {message}") - return None - - frames = guardrail_runner.postprocess(frames) - return frames diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/face_blur_filter/__init__.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/face_blur_filter/__init__.py deleted file mode 100644 index 3159bfe65645499015bd92609b99d476d69544e9..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/face_blur_filter/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
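A minimal usage sketch of the preset helpers defined above, assuming the guardrail checkpoints have already been downloaded to a local directory (the directory path and the placeholder frames are hypothetical):

```python
import numpy as np

from cosmos_transfer1.auxiliary.guardrail.common.presets import (
    create_text_guardrail_runner,
    create_video_guardrail_runner,
    run_text_guardrail,
    run_video_guardrail,
)

checkpoint_dir = "./checkpoints"  # hypothetical checkpoint location

# Pre-Guard: check the text prompt before generation.
text_runner = create_text_guardrail_runner(checkpoint_dir)
if run_text_guardrail("A truck drives down a snowy road.", text_runner):
    # ... run video generation here ...
    frames = np.zeros((8, 704, 1280, 3), dtype=np.uint8)  # placeholder frames [T, H, W, C]

    # Post-Guard: content check plus face blurring on the generated frames.
    video_runner = create_video_guardrail_runner(checkpoint_dir)
    safe_frames = run_video_guardrail(frames, video_runner)
    if safe_frames is None:
        print("Generated video was blocked by the guardrail.")
```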
diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/face_blur_filter/blur_utils.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/face_blur_filter/blur_utils.py deleted file mode 100644 index d52f69d220444a53027b3b4acc3bd192fc6eb76f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/face_blur_filter/blur_utils.py +++ /dev/null @@ -1,35 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import cv2 -import numpy as np - - -def pixelate_face(face_img: np.ndarray, blocks: int = 5) -> np.ndarray: - """ - Pixelate a face region by reducing resolution and then upscaling. - - Args: - face_img: Face region to pixelate - blocks: Number of blocks to divide the face into (in each dimension) - - Returns: - Pixelated face region - """ - h, w = face_img.shape[:2] - # Shrink the image and scale back up to create pixelation effect - temp = cv2.resize(face_img, (blocks, blocks), interpolation=cv2.INTER_LINEAR) - pixelated = cv2.resize(temp, (w, h), interpolation=cv2.INTER_NEAREST) - return pixelated diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/face_blur_filter/face_blur_filter.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/face_blur_filter/face_blur_filter.py deleted file mode 100644 index 2661101bf68e2a2f8b25e666c1067c4287f2286a..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/face_blur_filter/face_blur_filter.py +++ /dev/null @@ -1,225 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
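A quick, hypothetical example of `pixelate_face` from the file above, applied to a manually chosen face region (the image path and box coordinates are made up; in the pipeline the boxes come from RetinaFace detections):

```python
import cv2

from cosmos_transfer1.auxiliary.guardrail.face_blur_filter.blur_utils import pixelate_face

frame = cv2.imread("frame.png")        # hypothetical input frame (BGR)
x1, y1, x2, y2 = 100, 80, 220, 210     # hypothetical face bounding box
frame[y1:y2, x1:x2] = pixelate_face(frame[y1:y2, x1:x2], blocks=5)
cv2.imwrite("frame_pixelated.png", frame)
```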
- -import argparse -import os - -import numpy as np -import torch -from retinaface.data import cfg_re50 -from retinaface.layers.functions.prior_box import PriorBox -from retinaface.models.retinaface import RetinaFace -from torch.utils.data import DataLoader, TensorDataset -from tqdm import tqdm - -from cosmos_transfer1.auxiliary.guardrail.common.core import GuardrailRunner, PostprocessingGuardrail -from cosmos_transfer1.auxiliary.guardrail.common.io_utils import get_video_filepaths, read_video, save_video -from cosmos_transfer1.auxiliary.guardrail.face_blur_filter.blur_utils import pixelate_face -from cosmos_transfer1.auxiliary.guardrail.face_blur_filter.retinaface_utils import ( - decode_batch, - filter_detected_boxes, - load_model, -) -from cosmos_transfer1.utils import log, misc - -# RetinaFace model constants from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py -TOP_K = 5_000 -KEEP_TOP_K = 750 -NMS_THRESHOLD = 0.4 - - -class RetinaFaceFilter(PostprocessingGuardrail): - def __init__( - self, - checkpoint_dir: str, - batch_size: int = 1, - confidence_threshold: float = 0.7, - device="cuda" if torch.cuda.is_available() else "cpu", - ) -> None: - """ - Initialize the RetinaFace model for face detection and blurring. - - Args: - checkpoint: Path to the RetinaFace checkpoint file - batch_size: Batch size for RetinaFace inference and processing - confidence_threshold: Minimum confidence score to consider a face detection - """ - self.checkpoint = f"{checkpoint_dir}/nvidia/Cosmos-Guardrail1/face_blur_filter/Resnet50_Final.pth" - self.cfg = cfg_re50 - self.batch_size = batch_size - self.confidence_threshold = confidence_threshold - self.device = device - self.dtype = torch.float32 - - # Disable loading ResNet pretrained weights - self.cfg["pretrain"] = False - self.net = RetinaFace(cfg=self.cfg, phase="test") - cpu = self.device == "cpu" - - # Load from RetinaFace pretrained checkpoint - self.net = load_model(self.net, self.checkpoint, cpu) - self.net.to(self.device, dtype=self.dtype).eval() - - def preprocess_frames(self, frames: np.ndarray) -> torch.Tensor: - """Preprocess a sequence of frames for face detection. - - Args: - frames: Input frames - - Returns: - Preprocessed frames tensor - """ - with torch.no_grad(): - frames_tensor = torch.from_numpy(frames).to(self.device, dtype=self.dtype) # Shape: [T, H, W, C] - frames_tensor = frames_tensor.permute(0, 3, 1, 2) # Shape: [T, C, H, W] - frames_tensor = frames_tensor[:, [2, 1, 0], :, :] # RGB to BGR to match RetinaFace model input - means = torch.tensor([104.0, 117.0, 123.0], device=self.device, dtype=self.dtype).view(1, 3, 1, 1) - frames_tensor = frames_tensor - means # Subtract mean BGR values for each channel - return frames_tensor - - def blur_detected_faces( - self, - frames: np.ndarray, - batch_loc: torch.Tensor, - batch_conf: torch.Tensor, - prior_data: torch.Tensor, - scale: torch.Tensor, - min_size: tuple[int] = (20, 20), - ) -> list[np.ndarray]: - """Blur detected faces in a batch of frames using RetinaFace predictions. 
- - Args: - frames: Input frames - batch_loc: Batched location predictions - batch_conf: Batched confidence scores - prior_data: Prior boxes for the video - scale: Scale factor for resizing detections - min_size: Minimum size of a detected face region in pixels - - Returns: - Processed frames with pixelated faces - """ - with torch.no_grad(): - batch_boxes = decode_batch(batch_loc, prior_data, self.cfg["variance"]) - batch_boxes = batch_boxes * scale - - blurred_frames = [] - for i, boxes in enumerate(batch_boxes): - boxes = boxes.detach().cpu().numpy() - scores = batch_conf[i, :, 1].detach().cpu().numpy() - - filtered_boxes = filter_detected_boxes( - boxes, - scores, - confidence_threshold=self.confidence_threshold, - nms_threshold=NMS_THRESHOLD, - top_k=TOP_K, - keep_top_k=KEEP_TOP_K, - ) - - frame = frames[i] - for box in filtered_boxes: - x1, y1, x2, y2 = map(int, box) - # Ignore bounding boxes smaller than the minimum size - if x2 - x1 < min_size[0] or y2 - y1 < min_size[1]: - continue - max_h, max_w = frame.shape[:2] - face_roi = frame[max(y1, 0) : min(y2, max_h), max(x1, 0) : min(x2, max_w)] - blurred_face = pixelate_face(face_roi) - frame[max(y1, 0) : min(y2, max_h), max(x1, 0) : min(x2, max_w)] = blurred_face - blurred_frames.append(frame) - - return blurred_frames - - def postprocess(self, frames: np.ndarray) -> np.ndarray: - """Blur faces in a sequence of frames. - - Args: - frames: Input frames - - Returns: - Processed frames with pixelated faces - """ - # Create dataset and dataloader - frames_tensor = self.preprocess_frames(frames) - dataset = TensorDataset(frames_tensor) - dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False) - processed_frames, processed_batches = [], [] - - prior_data, scale = None, None - for i, batch in enumerate(dataloader): - batch = batch[0] - h, w = batch.shape[-2:] # Batch shape: [C, H, W] - - with torch.no_grad(): - # Generate priors for the video - if prior_data is None: - priorbox = PriorBox(self.cfg, image_size=(h, w)) - priors = priorbox.forward() - priors = priors.to(self.device, dtype=self.dtype) - prior_data = priors.data - - # Get scale for resizing detections - if scale is None: - scale = torch.Tensor([w, h, w, h]) - scale = scale.to(self.device, dtype=self.dtype) - - batch_loc, batch_conf, _ = self.net(batch) - - # Blur detected faces in each batch of frames - start_idx = i * self.batch_size - end_idx = min(start_idx + self.batch_size, len(frames)) - processed_batches.append( - self.blur_detected_faces(frames[start_idx:end_idx], batch_loc, batch_conf, prior_data, scale) - ) - - processed_frames = [frame for batch in processed_batches for frame in batch] - return np.array(processed_frames) - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--input_dir", type=str, required=True, help="Path containing input videos") - parser.add_argument("--output_dir", type=str, required=True, help="Path for saving processed videos") - parser.add_argument( - "--checkpoint", - type=str, - help="Path to the RetinaFace checkpoint file", - ) - return parser.parse_args() - - -def main(args): - filepaths = get_video_filepaths(args.input_dir) - if not filepaths: - log.error(f"No video files found in directory: {args.input_dir}") - return - - face_blur = RetinaFaceFilter(checkpoint=args.checkpoint) - postprocessing_runner = GuardrailRunner(postprocessors=[face_blur]) - os.makedirs(args.output_dir, exist_ok=True) - - for filepath in tqdm(filepaths): - video_data = read_video(filepath) - with misc.timer("face blur 
filter"): - frames = postprocessing_runner.postprocess(video_data.frames) - - output_path = os.path.join(args.output_dir, os.path.basename(filepath)) - save_video(output_path, frames, video_data.fps) - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/face_blur_filter/retinaface_utils.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/face_blur_filter/retinaface_utils.py deleted file mode 100644 index 7684648f3a4a2a48f787b4c0c2880ff7b436035f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/face_blur_filter/retinaface_utils.py +++ /dev/null @@ -1,117 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import torch -from retinaface.utils.nms.py_cpu_nms import py_cpu_nms - -from cosmos_transfer1.utils import log - - -# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py -def filter_detected_boxes(boxes, scores, confidence_threshold, nms_threshold, top_k, keep_top_k): - """Filter boxes based on confidence score and remove overlapping boxes using NMS.""" - # Keep detections with confidence above threshold - inds = np.where(scores > confidence_threshold)[0] - boxes = boxes[inds] - scores = scores[inds] - - # Sort by confidence and keep top K detections - order = scores.argsort()[::-1][:top_k] - boxes = boxes[order] - scores = scores[order] - - # Run non-maximum-suppression (NMS) to remove overlapping boxes - dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) - keep = py_cpu_nms(dets, nms_threshold) - dets = dets[keep, :] - dets = dets[:keep_top_k, :] - boxes = dets[:, :-1] - return boxes - - -# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/utils/box_utils.py to handle batched inputs -def decode_batch(loc, priors, variances): - """Decode batched locations from predictions using priors and variances. - - Args: - loc (tensor): Batched location predictions for loc layers. - Shape: [batch_size, num_priors, 4] - priors (tensor): Prior boxes in center-offset form. - Shape: [num_priors, 4] - variances: (list[float]): Variances of prior boxes. 
- - Return: - Decoded batched bounding box predictions - Shape: [batch_size, num_priors, 4] - """ - batch_size = loc.size(0) - priors = priors.unsqueeze(0).expand(batch_size, -1, -1) - - boxes = torch.cat( - ( - priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:], - priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1]), - ), - dim=2, - ) - - boxes[:, :, :2] -= boxes[:, :, 2:] / 2 - boxes[:, :, 2:] += boxes[:, :, :2] - return boxes - - -# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py -def _check_keys(model, pretrained_state_dict): - ckpt_keys = set(pretrained_state_dict.keys()) - model_keys = set(model.state_dict().keys()) - used_pretrained_keys = model_keys & ckpt_keys - unused_pretrained_keys = ckpt_keys - model_keys - missing_keys = model_keys - ckpt_keys - log.debug("Missing keys:{}".format(len(missing_keys))) - log.debug("Unused checkpoint keys:{}".format(len(unused_pretrained_keys))) - log.debug("Used keys:{}".format(len(used_pretrained_keys))) - assert len(used_pretrained_keys) > 0, "load NONE from pretrained checkpoint" - return True - - -# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py -def _remove_prefix(state_dict, prefix): - """Old version of the model is stored with all names of parameters sharing common prefix 'module.'""" - log.debug("Removing prefix '{}'".format(prefix)) - - def f(x): - return x.split(prefix, 1)[-1] if x.startswith(prefix) else x - - return {f(key): value for key, value in state_dict.items()} - - -# Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py -def load_model(model, pretrained_path, load_to_cpu): - log.debug("Loading pretrained model from {}".format(pretrained_path)) - if load_to_cpu: - pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage, weights_only=True) - else: - device = torch.cuda.current_device() - pretrained_dict = torch.load( - pretrained_path, map_location=lambda storage, loc: storage.cuda(device), weights_only=True - ) - if "state_dict" in pretrained_dict.keys(): - pretrained_dict = _remove_prefix(pretrained_dict["state_dict"], "module.") - else: - pretrained_dict = _remove_prefix(pretrained_dict, "module.") - _check_keys(model, pretrained_dict) - model.load_state_dict(pretrained_dict, strict=False) - return model diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/llamaGuard3/__init__.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/llamaGuard3/__init__.py deleted file mode 100644 index 3159bfe65645499015bd92609b99d476d69544e9..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/llamaGuard3/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
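A quick check of `decode_batch` from `retinaface_utils.py` above: it mirrors the single-image `decode` from the upstream Pytorch_Retinaface repo, just with an extra batch dimension, converting center-offset predictions back to corner-form boxes. The dummy-tensor sketch below makes the convention concrete without loading any model; the `[0.1, 0.2]` variances are the usual SSD/RetinaFace defaults and are an assumption here, not something read from a checkpoint.

```python
# Shape/convention check for decode_batch above, on dummy tensors
# (decode_batch is assumed importable from retinaface_utils).
import torch

batch_size, num_priors = 2, 6
priors = torch.rand(num_priors, 4)             # (cx, cy, w, h), normalized
loc = torch.zeros(batch_size, num_priors, 4)   # zero offsets -> boxes should equal the priors
variances = [0.1, 0.2]

boxes = decode_batch(loc, priors, variances)   # corner form (x1, y1, x2, y2)
assert boxes.shape == (batch_size, num_priors, 4)

# With zero offsets, decoding is just the center-form -> corner-form conversion.
expected = torch.cat([priors[:, :2] - priors[:, 2:] / 2,
                      priors[:, :2] + priors[:, 2:] / 2], dim=1)
assert torch.allclose(boxes[0], expected)
```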
diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/llamaGuard3/categories.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/llamaGuard3/categories.py deleted file mode 100644 index f8d5a95d4dce1202e3acec0e10177c97c1e5924e..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/llamaGuard3/categories.py +++ /dev/null @@ -1,31 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -UNSAFE_CATEGORIES = { - "S1": "Violent Crimes.", - "S2": "Non-Violent Crimes.", - "S3": "Sex Crimes.", - "S4": "Child Exploitation.", - "S5": "Defamation.", - "S6": "Specialized Advice.", - "S7": "Privacy.", - "S8": "Intellectual Property.", - "S9": "Indiscriminate Weapons.", - "S10": "Hate.", - "S11": "Self-Harm.", - "S12": "Sexual Content.", - "S13": "Elections.", - "s14": "Code Interpreter Abuse.", -} diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/llamaGuard3/llamaGuard3.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/llamaGuard3/llamaGuard3.py deleted file mode 100644 index a0fc1d37b6f94ad1d9336c11ccbd6416e509a7c3..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/llamaGuard3/llamaGuard3.py +++ /dev/null @@ -1,122 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
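One detail worth noting about the `UNSAFE_CATEGORIES` table above: its last key is lower-case (`"s14"`) while Llama Guard 3 reports category codes in upper case (`S1`…`S14`), so an exact-key lookup would treat S14 as unrecognized. A case-insensitive lookup along the following lines sidesteps that; it is an illustrative workaround, not part of the original module.

```python
# Illustrative case-insensitive lookup into UNSAFE_CATEGORIES above.
UNSAFE_CATEGORIES_NORMALISED = {k.upper(): v for k, v in UNSAFE_CATEGORIES.items()}


def category_name(code: str) -> str | None:
    """Map a reported code such as 'S14' to its human-readable name, if known."""
    return UNSAFE_CATEGORIES_NORMALISED.get(code.strip().upper())


assert category_name("S14") == "Code Interpreter Abuse."
assert category_name("s10") == "Hate."
assert category_name("S99") is None
```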
- -import argparse -import os - -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer - -from cosmos_transfer1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner -from cosmos_transfer1.auxiliary.guardrail.llamaGuard3.categories import UNSAFE_CATEGORIES -from cosmos_transfer1.utils import log, misc - -SAFE = misc.Color.green("SAFE") -UNSAFE = misc.Color.red("UNSAFE") - - -class LlamaGuard3(ContentSafetyGuardrail): - def __init__( - self, - checkpoint_dir: str, - device="cuda" if torch.cuda.is_available() else "cpu", - ) -> None: - self.checkpoint_dir = checkpoint_dir - self.device = device - self.dtype = torch.bfloat16 - - model_id = "meta-llama/Llama-Guard-3-8B" - model_dir = os.path.join(self.checkpoint_dir, model_id) - - self.model = AutoModelForCausalLM.from_pretrained(model_dir) - self.tokenizer = AutoTokenizer.from_pretrained(model_dir) - - self.model.to(self.device, dtype=self.dtype).eval() - - def get_llamaGuard3_block_message(self, moderation_output: str) -> str: - """Extract the blocked category from the Llama Guard 3 model output.""" - block_msg = "Prompt blocked by Llama Guard 3." - try: - lines = moderation_output.splitlines() - categories_detected = [] - for line in lines[1:]: - line_stripped = line.split("<|eot_id|>")[0].strip() - for catagory in line_stripped.split(","): - catagory = catagory.strip() - if catagory not in UNSAFE_CATEGORIES: - log.warning(f"Unrecognized category from moderation output: {catagory}") - else: - categories_detected.append(catagory) - if len(categories_detected) > 0: - blocked_catagories = ", ".join([UNSAFE_CATEGORIES[catagory][:-1] for catagory in categories_detected]) - block_msg = f"{block_msg} Violations: {blocked_catagories}." - except Exception as e: - log.warning(f"Unable to extract blocked category from Llama Guard 3 output: {e}") - return block_msg - - def filter_llamaGuard3_output(self, prompt: str) -> tuple[bool, str]: - """Filter the Llama Guard 3 model output and return the safety status and message.""" - conversation = [{"role": "user", "content": prompt}] - input_ids = self.tokenizer.apply_chat_template( - conversation, categories=UNSAFE_CATEGORIES, return_tensors="pt" - ).to("cuda") - prompt_len = input_ids.shape[1] - output = self.model.generate( - input_ids=input_ids, - max_new_tokens=100, - return_dict_in_generate=True, - pad_token_id=0, - ) - generated_tokens = output.sequences[:, prompt_len:] - moderation_output = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=False).strip() - - if "unsafe" in moderation_output.lower(): - block_msg = self.get_llamaGuard3_block_message(moderation_output) - return False, block_msg - else: - return True, "" - - def is_safe(self, prompt: str) -> tuple[bool, str]: - """Check if the input prompt is safe according to the Llama Guard 3 model.""" - try: - return self.filter_llamaGuard3_output(prompt) - except Exception as e: - log.error(f"Unexpected error occurred when running Llama Guard 3 guardrail: {e}") - return True, "Unexpected error occurred when running Llama Guard 3 guardrail." 
- - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--prompt", type=str, required=True, help="Input prompt") - parser.add_argument( - "--checkpoint_dir", - type=str, - help="Path to the Llama Guard 3 checkpoint folder", - ) - return parser.parse_args() - - -def main(args): - llamaGuard3 = LlamaGuard3(checkpoint_dir=args.checkpoint_dir) - runner = GuardrailRunner(safety_models=[llamaGuard3]) - with misc.timer("Llama Guard 3 safety check"): - safety, message = runner.run_safety_check(args.prompt) - log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}") - log.info(f"Message: {message}") if not safety else None - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/__init__.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/__init__.py deleted file mode 100644 index 3159bfe65645499015bd92609b99d476d69544e9..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/model.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/model.py deleted file mode 100644 index 2c9cf962d74d213fea7d33faad0a4a8b3d9b17f4..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/model.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import attrs -import torch -import torch.nn as nn - -from cosmos_transfer1.utils.ddp_config import make_freezable - - -@make_freezable -@attrs.define(slots=False) -class ModelConfig: - input_size: int = 1152 - num_classes: int = 7 - - -class SafetyClassifier(nn.Module): - def __init__(self, input_size: int = 1024, num_classes: int = 2): - super().__init__() - self.input_size = input_size - self.num_classes = num_classes - self.layers = nn.Sequential( - nn.Linear(self.input_size, 512), - nn.BatchNorm1d(512), - nn.ReLU(), - nn.Linear(512, 256), - nn.BatchNorm1d(256), - nn.ReLU(), - nn.Linear(256, self.num_classes), - # Note: No activation function here; CrossEntropyLoss expects raw logits - ) - - def forward(self, x): - return self.layers(x) - - -class VideoSafetyModel(nn.Module): - def __init__(self, config: ModelConfig) -> None: - super().__init__() - self.config = config - self.num_classes = config.num_classes - self.network = SafetyClassifier(input_size=config.input_size, num_classes=self.num_classes) - - @torch.inference_mode() - def forward(self, data_batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - logits = self.network(data_batch["data"].cuda()) - return {"logits": logits} diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/video_content_safety_filter.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/video_content_safety_filter.py deleted file mode 100644 index 77f62f553d0a7db4eced90b383da732182830eef..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/video_content_safety_filter.py +++ /dev/null @@ -1,185 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
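The classifier above is a small MLP head over frozen image embeddings: 1152-dimensional SigLIP features in, 7 class logits out. The dummy forward pass below shows the expected shapes; it calls the underlying `network` directly (as the video filter that follows does), which also keeps it runnable on CPU by bypassing the `.cuda()` call in `forward()`. The random embeddings are placeholders for real SigLIP features, and the classes above are assumed importable as-is.

```python
# Dummy forward pass through the safety classifier head defined above.
import torch

config = ModelConfig(input_size=1152, num_classes=7)
model = VideoSafetyModel(config)
model.eval()

fake_embeddings = torch.randn(4, 1152)        # 4 frames' worth of SigLIP features
with torch.inference_mode():
    logits = model.network(fake_embeddings)   # bypasses the .cuda() in forward()
    probs = torch.softmax(logits, dim=-1)

print(logits.shape)          # torch.Size([4, 7])
print(probs.argmax(dim=-1))  # predicted class index per frame
```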
- -import argparse -import json -import os -from typing import Iterable, Tuple, Union - -import torch -from PIL import Image - -from cosmos_transfer1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner -from cosmos_transfer1.auxiliary.guardrail.common.io_utils import get_video_filepaths, read_video -from cosmos_transfer1.auxiliary.guardrail.video_content_safety_filter.model import ModelConfig, VideoSafetyModel -from cosmos_transfer1.auxiliary.guardrail.video_content_safety_filter.vision_encoder import SigLIPEncoder -from cosmos_transfer1.utils import log, misc - -# Define the class index to class name mapping for multi-class classification -CLASS_IDX_TO_NAME = { - 0: "Safe", - 1: "Sexual_Content", - 3: "Drugs", - 4: "Child_Abuse", - 5: "Hate_and_Harassment", - 6: "Self-Harm", -} - - -class VideoContentSafetyFilter(ContentSafetyGuardrail): - def __init__( - self, - checkpoint_dir: str, - device="cuda" if torch.cuda.is_available() else "cpu", - ) -> None: - self.checkpoint_dir = os.path.join(checkpoint_dir, "nvidia/Cosmos-Guardrail1/video_content_safety_filter") - self.device = device - self.dtype = torch.float32 - - # Initialize the SigLIP encoder - self.encoder = SigLIPEncoder(checkpoint_dir=self.checkpoint_dir, device=device, dtype=self.dtype) - - # Use ModelConfig directly for inference configuration - model_config = ModelConfig(input_size=1152, num_classes=7) - - # Load the multi-class classifier - self.model = VideoSafetyModel(model_config) - safety_filter_local_path = os.path.join(self.checkpoint_dir, "safety_filter.pt") - checkpoint = torch.load(safety_filter_local_path, map_location=torch.device("cpu"), weights_only=True) - self.model.load_state_dict(checkpoint["model"]) - self.model.to(self.device, dtype=self.dtype).eval() - - @torch.inference_mode() - def __infer(self, pil_image: Image.Image) -> int: - """Infer the class of the image.""" - image_embs = self.encoder.encode_image(pil_image) - logits = self.model.network(image_embs) - probabilities = torch.nn.functional.softmax(logits, dim=-1) - predicted_class = torch.argmax(probabilities, dim=-1).item() - return predicted_class - - def is_safe_file(self, filepath: str) -> bool: - """Check if the video file is safe.""" - video_data = read_video(filepath) - - # Sample frames at 2 FPS - sample_rate = 2 # frames per second - frame_interval = int(video_data.fps / sample_rate) - frame_numbers = list(range(0, int(video_data.fps * video_data.duration), frame_interval)) - - is_safe = True - frame_scores = [] - - for frame_number in frame_numbers: - try: - frame = video_data.frames[frame_number] - pil_image = Image.fromarray(frame) - predicted_class = self.__infer(pil_image) - class_name = CLASS_IDX_TO_NAME.get(predicted_class, "Safe") - frame_scores.append({"frame_number": frame_number, "class": class_name}) - - # If any frame is not "Safe", mark the video as unsafe - if class_name != "Safe": - is_safe = False - break - - except Exception as e: - log.warning(f"Warning: Failed to run safety classifier on frame_number {frame_number}. 
Exception: {e}") - continue - - # Prepare data for JSON - video_data = { - "filepath": filepath, - "is_safe": is_safe, - "video_length": video_data.duration, - "fps": video_data.fps, - "frame_scores": frame_scores, - } - - log.info(f"Video {filepath} is {'SAFE' if is_safe else 'UNSAFE'}.") - log.debug(f"Video data: {json.dumps(video_data, indent=4)}") - return is_safe - - def is_safe_frames(self, frames: Iterable) -> bool: - """Check if the generated video frames are safe.""" - frame_scores = [] - total_frames = 0 - safe_frames = 0 - - for frame_number, frame in enumerate(frames): - try: - total_frames += 1 - pil_image = Image.fromarray(frame) - predicted_class = self.__infer(pil_image) - class_name = CLASS_IDX_TO_NAME.get(predicted_class, "Safe") - frame_scores.append({"frame_number": frame_number, "class": class_name}) - - if class_name == "Safe": - safe_frames += 1 - - except Exception as e: - log.warning(f"Warning: Failed to run safety classifier on frame_number {frame_number}. Exception: {e}") - continue - - # Decide if the video is safe based on the ratio of safe frames - is_safe = False - if total_frames > 0: - is_safe = (safe_frames / total_frames) >= 0.95 - - video_data = { - "is_safe": is_safe, - "frame_scores": frame_scores, - } - - log.debug(f"Frames data: {json.dumps(video_data, indent=4)}") - return is_safe - - def is_safe(self, input: Union[str, Iterable]) -> Tuple[bool, str]: - if isinstance(input, str): - is_safe = self.is_safe_file(input) - return is_safe, "safe video detected" if is_safe else "unsafe video detected" - else: - is_safe = self.is_safe_frames(input) - return is_safe, "safe frames detected" if is_safe else "unsafe frames detected" - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--input_dir", type=str, required=True, help="Path containing input videos") - parser.add_argument( - "--checkpoint_dir", - type=str, - help="Path to the Video Content Safety Filter checkpoint folder", - ) - return parser.parse_args() - - -def main(args): - filepaths = get_video_filepaths(args.input_dir) - if not filepaths: - log.error(f"No video files found in directory: {args.input_dir}") - return - - video_filter = VideoContentSafetyFilter(checkpoint_dir=args.checkpoint_dir) - runner = GuardrailRunner(safety_models=[video_filter], generic_safe_msg="Video is safe") - - for filepath in filepaths: - with misc.timer("video content safety filter"): - _ = runner.run_safety_check(filepath) - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/vision_encoder.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/vision_encoder.py deleted file mode 100644 index c5c3232ebe19bc7756a068c123022116c677a9d8..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/vision_encoder.py +++ /dev/null @@ -1,46 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import torch -from PIL import Image -from transformers import SiglipModel, SiglipProcessor - - -class SigLIPEncoder(torch.nn.Module): - def __init__( - self, - checkpoint_dir: str, - model_name: str = "google/siglip-so400m-patch14-384", - device="cuda" if torch.cuda.is_available() else "cpu", - dtype=torch.float32, - ) -> None: - super().__init__() - self.checkpoint_dir = checkpoint_dir - self.device = device - self.dtype = dtype - self.model = SiglipModel.from_pretrained(model_name, cache_dir=self.checkpoint_dir) - self.processor = SiglipProcessor.from_pretrained(model_name, cache_dir=self.checkpoint_dir) - self.model.to(self.device, dtype=self.dtype).eval() - - @torch.inference_mode() - def encode_image(self, input_img: Image.Image) -> torch.Tensor: - """Encode an image into a feature vector.""" - with torch.no_grad(): - inputs = self.processor(images=input_img, return_tensors="pt").to(self.device, dtype=self.dtype) - image_features = self.model.get_image_features(**inputs) - image_features /= image_features.norm(dim=-1, keepdim=True) - return image_features diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/human_keypoint/human_keypoint.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/human_keypoint/human_keypoint.py deleted file mode 100644 index 6b340045495bf159f5a7698300fbe4f1adad963f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/human_keypoint/human_keypoint.py +++ /dev/null @@ -1,155 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import cv2 -import numpy as np -from rtmlib import Wholebody - -from cosmos_transfer1.diffusion.datasets.augmentors.human_keypoint_utils import ( - coco_wholebody_133_skeleton, - openpose134_skeleton, -) -from cosmos_transfer1.utils import log - - -class HumanKeypointModel: - def __init__(self, to_openpose=True, conf_thres=0.6): - self.model = Wholebody( - to_openpose=to_openpose, - mode="performance", - backend="onnxruntime", - device="cuda", - ) - self.to_openpose = to_openpose - self.conf_thres = conf_thres - - def __call__(self, input_video: str, output_video: str = "keypoint.mp4") -> str: - """ - Generate the human body keypoint plot for the keypointControlNet video2world model. - Input: mp4 video - Output: mp4 keypoint video, of the same spatial and temporal dimensions as the input video. 
- """ - - log.info(f"Processing video: {input_video} to generate keypoint video: {output_video}") - assert os.path.exists(input_video) - - cap = cv2.VideoCapture(input_video) - fps = int(cap.get(cv2.CAP_PROP_FPS)) - frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - frame_size = (frame_width, frame_height) - - # vid writer - fourcc = cv2.VideoWriter_fourcc(*"mp4v") - skeleton_writer = cv2.VideoWriter(output_video, fourcc, fps, frame_size) - - log.info(f"frame width: {frame_width}, frame height: {frame_height}, fps: {fps}") - log.info("start pose estimation for frames..") - - # Process each frame - while cap.isOpened(): - ret, frame = cap.read() - if not ret: - break - - # Create a black background frame - black_frame = np.zeros_like(frame) - - # Run pose estimation - keypoints, scores = self.model(frame) - - if keypoints is not None and len(keypoints) > 0: - skeleton_frame = self.plot_person_kpts( - black_frame, - keypoints, - scores, - kpt_thr=self.conf_thres, - openpose_format=True, - line_width=4, - ) # (h, w, 3) - else: - skeleton_frame = black_frame - - skeleton_writer.write(skeleton_frame[:, :, ::-1]) - - cap.release() - skeleton_writer.release() - - def draw_skeleton( - self, - img: np.ndarray, - keypoints: np.ndarray, - scores: np.ndarray, - kpt_thr: float = 0.6, - openpose_format: bool = True, - radius: int = 2, - line_width: int = 4, - ): - skeleton_topology = openpose134_skeleton if openpose_format else coco_wholebody_133_skeleton - assert len(keypoints.shape) == 2 - keypoint_info, skeleton_info = ( - skeleton_topology["keypoint_info"], - skeleton_topology["skeleton_info"], - ) - vis_kpt = [s >= kpt_thr for s in scores] - link_dict = {} - for i, kpt_info in keypoint_info.items(): - kpt_color = tuple(kpt_info["color"]) - link_dict[kpt_info["name"]] = kpt_info["id"] - - kpt = keypoints[i] - - if vis_kpt[i]: - img = cv2.circle(img, (int(kpt[0]), int(kpt[1])), int(radius), kpt_color, -1) - - for i, ske_info in skeleton_info.items(): - link = ske_info["link"] - pt0, pt1 = link_dict[link[0]], link_dict[link[1]] - - if vis_kpt[pt0] and vis_kpt[pt1]: - link_color = ske_info["color"] - kpt0 = keypoints[pt0] - kpt1 = keypoints[pt1] - - img = cv2.line( - img, (int(kpt0[0]), int(kpt0[1])), (int(kpt1[0]), int(kpt1[1])), link_color, thickness=line_width - ) - - return img - - def plot_person_kpts( - self, - pose_vis_img: np.ndarray, - keypoints: np.ndarray, - scores: np.ndarray, - kpt_thr: float = 0.6, - openpose_format: bool = True, - line_width: int = 4, - ) -> np.ndarray: - """ - plot a single person - in-place update the pose image - """ - for kpts, ss in zip(keypoints, scores): - try: - pose_vis_img = self.draw_skeleton( - pose_vis_img, kpts, ss, kpt_thr=kpt_thr, openpose_format=openpose_format, line_width=line_width - ) - except ValueError as e: - log.error(f"Error in draw_skeleton func, {e}") - - return pose_vis_img diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/robot_augmentation/README.md b/cosmos-transfer1/cosmos_transfer1/auxiliary/robot_augmentation/README.md deleted file mode 100644 index 13a77710806abd177c4360ad4687e62e9265e8e3..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/robot_augmentation/README.md +++ /dev/null @@ -1,112 +0,0 @@ -# Robot Data Augmentation with Cosmos-Transfer1 - -This pipeline provides a two-step process to augment robotic videos using **Cosmos-Transfer1-7B**. 
It leverages **spatial-temporal control** to modify backgrounds while preserving the shape and/or appearance of the robot foreground. - -## Overview of Settings - -We propose two augmentation settings: - -### Setting 1 (fg_vis_edge_bg_seg): Preserve Shape and Appearance of the Robot (foreground) -- **Foreground Controls**: `Edge`, `Vis` -- **Background Controls**: `Segmentation` -- **Weights**: - - `w_edge(FG) = 1` - - `w_vis(FG) = 1` - - `w_seg(BG) = 1` - - All other weights = 0 - -### Setting 2 (fg_edge_bg_seg): Preserve Only Shape of the Robot (foreground) -- **Foreground Controls**: `Edge` -- **Background Controls**: `Segmentation` -- **Weights**: - - `w_edge(FG) = 1` - - `w_seg(BG) = 1` - - All other weights = 0 - -## Step-by-Step Instructions - -### Step 1: Generate Spatial-Temporal Weights - -This script extracts foreground (robot) and background information from semantic segmentation data. It processes per-frame segmentation masks and color-to-class mappings to generate spatial-temporal weight matrices for each control modality based on the selected setting. - -#### Input Requirements: -- A `segmentation` folder containing per-frame segmentation masks in PNG format -- A `segmentation_label` folder containing color-to-class mapping JSON files for each frame, for example: - ```json - { - "(29, 0, 0, 255)": { - "class": "gripper0_right_r_palm_vis" - }, - "(31, 0, 0, 255)": { - "class": "gripper0_right_R_thumb_proximal_base_link_vis" - }, - "(33, 0, 0, 255)": { - "class": "gripper0_right_R_thumb_proximal_link_vis" - } - } - ``` -- An input video file - -Here is an example input format: -[Example input directory](https://github.com/google-deepmind/cosmos/tree/main/assets/robot_augmentation_example/example1) - -#### Usage - -```bash -PYTHONPATH=$(pwd) python cosmos_transfer1/auxiliary/robot_augmentation/spatial_temporal_weight.py \ - --setting setting1 \ - --robot-keywords world_robot gripper robot \ - --input-dir assets/robot_augmentation_example \ - --output-dir outputs/robot_augmentation_example -``` - -#### Parameters: - -* `--setting`: Weight setting to use (choices: 'setting1', 'setting2', default: 'setting1') - * setting1: Emphasizes robot in visual and edge features (vis: 1.0 foreground, edge: 1.0 foreground, seg: 1.0 background) - * setting2: Emphasizes robot only in edge features (edge: 1.0 foreground, seg: 1.0 background) - -* `--input-dir`: Input directory containing example folders - * Default: 'assets/robot_augmentation_example' - -* `--output-dir`: Output directory for weight matrices - * Default: 'outputs/robot_augmentation_example' - -* `--robot-keywords`: Keywords used to identify robot classes - * Default: ["world_robot", "gripper", "robot"] - * Any semantic class containing these keywords will be treated as robot foreground - -### Step 2: Run Cosmos-Transfer1 Inference - -Use the generated spatial-temporal weight matrices to perform video augmentation with the proper controls. 
- -```bash -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPU="${NUM_GPU:=1}" - -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 \ -cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_folder outputs/robot_example_spatial_temporal_setting1 \ - --controlnet_specs assets/robot_augmentation_example/example1/inference_cosmos_transfer1_robot_spatiotemporal_weights.json \ - --offload_text_encoder_model \ - --offload_guardrail_models \ - --num_gpus $NUM_GPU -``` - -- Augmented videos are saved in `outputs/robot_example_spatial_temporal_setting1/` - -## Input Outputs Example - -Input video: - - - -You can run multiple times with different prompts (e.g., `assets/robot_augmentation_example/example1/example1_prompts.json`), and you can get different augmentation results: - - diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/robot_augmentation/spatial_temporal_weight.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/robot_augmentation/spatial_temporal_weight.py deleted file mode 100644 index 2cc6212464f60ae72d9d32d172222b38b4079ddc..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/robot_augmentation/spatial_temporal_weight.py +++ /dev/null @@ -1,577 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script processes segmentation results for each video frame saved as JSON files and generates a spatial-temporal weight matrix saved as a .pt file. -# The input JSON files contain segmentation information for each frame, and the output .pt file represents the spatial-temporal weight matrix for the video. 
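For orientation before the implementation: each run of this script writes one `[T, H, W]` half-precision tensor per control modality (`depth_weights.pt`, `vis_weights.pt`, `edge_weights.pt`, `seg_weights.pt`) into the example's output folder. A quick way to sanity-check an output directory is sketched below; the example path follows the README defaults above and is an assumption.

```python
# Quick sanity check of the weight tensors this script produces.
# The directory path follows the README defaults and is an assumption.
import os
import torch

weights_dir = "outputs/robot_augmentation_example/example1"

for name in ("depth", "vis", "edge", "seg"):
    path = os.path.join(weights_dir, f"{name}_weights.pt")
    w = torch.load(path, map_location="cpu")
    # Expect shape [num_frames, height, width], dtype float16,
    # values 0.0 and/or 1.0 per the foreground/background weight settings.
    print(name, tuple(w.shape), w.dtype, w.unique().tolist())
```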
- -import argparse -import glob -import json -import logging -import os -import re -from collections import defaultdict - -import cv2 -import numpy as np -import torch -from tqdm import tqdm - -# Configure logging -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") -logger = logging.getLogger(__name__) - - -# Class to manage different weight settings -class WeightSettings: - """Class to manage different weight settings for the features""" - - @staticmethod - def get_settings(setting_name): - """Get weight settings by name - - Args: - setting_name (str): Name of the setting - - Returns: - dict: Dictionary with weights for each feature - """ - settings = { - # Default setting: Emphasize robot in all features - "fg_vis_edge_bg_seg": { - "depth": {"foreground": 0.0, "background": 0.0}, - "vis": {"foreground": 1.0, "background": 0.0}, - "edge": {"foreground": 1.0, "background": 0.0}, - "seg": {"foreground": 0.0, "background": 1.0}, - }, - "fg_edge_bg_seg": { - "depth": {"foreground": 0.0, "background": 0.0}, - "vis": {"foreground": 0.0, "background": 0.0}, - "edge": {"foreground": 1.0, "background": 0.0}, - "seg": {"foreground": 0.0, "background": 1.0}, - }, - } - - if setting_name not in settings: - logger.warning(f"Setting '{setting_name}' not found. Using default.") - return settings["fg_vis_edge_bg_seg"] - - return settings[setting_name] - - @staticmethod - def list_settings(): - """List all available settings - - Returns: - list: List of setting names - """ - return ["fg_vis_edge_bg_seg", "fg_edge_bg_seg"] - - -def get_video_info(video_path): - """Get video dimensions and frame count""" - cap = cv2.VideoCapture(video_path) - if not cap.isOpened(): - raise ValueError(f"Could not open video file: {video_path}") - - width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - fps = cap.get(cv2.CAP_PROP_FPS) - - cap.release() - return width, height, frame_count, fps - - -def parse_color_key(color_key): - """Parse a color key string into an RGB tuple - - Args: - color_key (str): Color key string in the format "(r,g,b,a)" or similar - - Returns: - tuple: RGB tuple (r, g, b) - """ - # Extract numbers using regex to handle different formats - numbers = re.findall(r"\d+", color_key) - if len(numbers) >= 3: - r, g, b = map(int, numbers[:3]) - return (r, g, b) - else: - raise ValueError(f"Invalid color key format: {color_key}") - - -def save_visualization(mask, frame_num, feature_name, viz_dir): - """Save a visualization of the binary mask - - Args: - mask (numpy.ndarray): The mask (values 0 or 255) - frame_num (int): The frame number - feature_name (str): The name of the feature (depth, vis, edge, seg) - viz_dir (str): Directory to save visualizations - """ - # Simply save the binary mask directly - output_path = os.path.join(viz_dir, f"{feature_name}_frame_{frame_num:06d}.png") - cv2.imwrite(output_path, mask) - logger.info(f"Saved binary visualization to {output_path}") - - -def process_segmentation_files( - segmentation_dir, - output_dir, - viz_dir, - video_path=None, - weights_dict=None, - setting_name="fg_vis_edge_bg_seg", - robot_keywords=None, -): - """Process all segmentation JSON files and create weight matrices - - Args: - segmentation_dir (str): Directory containing segmentation JSON files - output_dir (str): Directory to save weight matrices - viz_dir (str): Directory to save visualizations - video_path (str, optional): Path to the video file. 
Defaults to None. - weights_dict (dict, optional): Dictionary with weights for each feature. - Format: { - 'depth': {'foreground': float, 'background': float}, - 'vis': {'foreground': float, 'background': float}, - 'edge': {'foreground': float, 'background': float}, - 'seg': {'foreground': float, 'background': float} - } - Values should be in range 0-1. Defaults to None. - setting_name (str, optional): Weight setting name. Defaults to 'fg_vis_edge_bg_seg (setting1)'. - robot_keywords (list, optional): List of keywords to identify robot classes. Defaults to ["robot"]. - """ - - # Set default robot keywords if not provided - if robot_keywords is None: - robot_keywords = ["robot"] - - # Get all JSON files - json_files = sorted(glob.glob(os.path.join(segmentation_dir, "*.json"))) - logger.info(f"Found {len(json_files)} JSON files") - - if len(json_files) == 0: - raise ValueError(f"No JSON files found in {segmentation_dir}") - - # For example directories, check for PNG files - png_dir = os.path.join(os.path.dirname(segmentation_dir), "segmentation") - png_files = [] - if os.path.exists(png_dir): - png_files = sorted(glob.glob(os.path.join(png_dir, "*.png"))) - logger.info(f"Found {len(png_files)} PNG files in segmentation directory") - - # Step 1: Create a unified color-to-class mapping from all JSON files - logger.info("Creating unified color-to-class mapping...") - rgb_to_class = {} - rgb_to_is_robot = {} - - for json_file in tqdm(json_files, desc="Processing JSON files for unified mapping"): - with open(json_file, "r") as f: - json_data = json.load(f) - - for color_key, data in json_data.items(): - color = parse_color_key(color_key) - class_name = data["class"] - - # Store RGB color for matching - rgb_to_class[color] = class_name - rgb_to_is_robot[color] = any(keyword in class_name for keyword in robot_keywords) - - # Print statistics about the unified color mapping - robot_colors = [color for color, is_robot in rgb_to_is_robot.items() if is_robot] - logger.info(f"Unified mapping: Found {len(robot_colors)} robot colors out of {len(rgb_to_is_robot)} total colors") - if robot_colors: - logger.info(f"Robot classes: {[rgb_to_class[color] for color in robot_colors]}") - - # Convert color mapping to arrays for vectorized operations - colors = list(rgb_to_is_robot.keys()) - color_array = np.array(colors) - is_robot_array = np.array([rgb_to_is_robot[color] for color in colors], dtype=bool) - - # If we have PNG files, get dimensions from the first PNG - if png_files: - # Get dimensions from the first PNG file - first_png = cv2.imread(png_files[0]) - if first_png is None: - raise ValueError(f"Could not read PNG file: {png_files[0]}") - - height, width = first_png.shape[:2] - frame_count = len(png_files) - - # Match frame numbers between JSON and PNG files to ensure correct correspondence - json_frame_nums = [int(os.path.basename(f).split("_")[-1].split(".")[0]) for f in json_files] - png_frame_nums = [int(os.path.basename(f).split("_")[-1].split(".")[0]) for f in png_files] - - # Find common frames between JSON and PNG files - common_frames = sorted(set(json_frame_nums).intersection(set(png_frame_nums))) - logger.info(f"Found {len(common_frames)} common frames between JSON and PNG files") - - if len(common_frames) == 0: - raise ValueError("No matching frames found between JSON and PNG files") - - # Create maps to easily look up files by frame number - json_map = {int(os.path.basename(f).split("_")[-1].split(".")[0]): f for f in json_files} - png_map = 
{int(os.path.basename(f).split("_")[-1].split(".")[0]): f for f in png_files} - - # Create new lists with only matching files - json_files = [json_map[frame] for frame in common_frames if frame in json_map] - png_files = [png_map[frame] for frame in common_frames if frame in png_map] - num_frames = len(json_files) - - logger.info(f"Using PNG dimensions: {width}x{height}, processing {num_frames} frames") - else: - # Get video information if no PNG files available - try: - width, height, frame_count, fps = get_video_info(video_path) - logger.info(f"Video dimensions: {width}x{height}, {frame_count} frames, {fps} fps") - num_frames = min(len(json_files), frame_count) - except Exception as e: - logger.warning(f"Warning: Could not get video information: {e}") - # Use a default size if we can't get the video info - width, height = 640, 480 - num_frames = len(json_files) - logger.info(f"Using default dimensions: {width}x{height}, {num_frames} frames") - - # Initialize weight tensors - depth_weights = torch.zeros((num_frames, height, width)) - vis_weights = torch.zeros((num_frames, height, width)) - edge_weights = torch.zeros((num_frames, height, width)) - seg_weights = torch.zeros((num_frames, height, width)) - - # Process frames - if png_files: - # Process PNG files directly - for i, (json_file, png_file) in enumerate(zip(json_files, png_files)): - # Get frame number from filename - frame_num = int(os.path.basename(json_file).split("_")[-1].split(".")[0]) - - # Read the corresponding PNG file - frame = cv2.imread(png_file) - - if frame is None: - logger.warning(f"Warning: Could not read frame {i} from PNG. Using blank frame.") - frame = np.zeros((height, width, 3), dtype=np.uint8) - - # Convert frame to RGB - frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - - # Calculate total pixels - total_pixels = height * width - - # Vectorized approach for finding nearest colors - # Convert frame_rgb to a 2D array of shape (height*width, 3) - pixels = frame_rgb.reshape(-1, 3) - - # Calculate distances between each pixel and each color (vectorized) - # This creates a matrix of shape (height*width, num_colors) - distances = np.sqrt(np.sum((pixels[:, np.newaxis, :] - color_array[np.newaxis, :, :]) ** 2, axis=2)) - - # Find the index of the nearest color for each pixel - nearest_color_indices = np.argmin(distances, axis=1) - - # Get the is_robot value for each pixel based on its nearest color - pixel_is_robot = is_robot_array[nearest_color_indices] - - # Reshape back to image dimensions - pixel_is_robot_2d = pixel_is_robot.reshape(height, width) - - # Count robot and matched pixels - robot_pixel_count = np.sum(pixel_is_robot) - matched_pixel_count = pixels.shape[0] # All pixels are matched now - - # Create masks based on the is_robot classification - depth_mask = np.where( - pixel_is_robot_2d, weights_dict["depth"]["foreground"], weights_dict["depth"]["background"] - ) - - vis_mask = np.where(pixel_is_robot_2d, weights_dict["vis"]["foreground"], weights_dict["vis"]["background"]) - - edge_mask = np.where( - pixel_is_robot_2d, weights_dict["edge"]["foreground"], weights_dict["edge"]["background"] - ) - - seg_mask = np.where(pixel_is_robot_2d, weights_dict["seg"]["foreground"], weights_dict["seg"]["background"]) - - # Create visualization mask - visualization_mask = np.zeros((height, width), dtype=np.uint8) - visualization_mask[pixel_is_robot_2d] = 255 - - # Log statistics - robot_percentage = (robot_pixel_count / total_pixels) * 100 - matched_percentage = (matched_pixel_count / total_pixels) * 100 - 
logger.info(f"Frame {frame_num}: {robot_pixel_count} robot pixels ({robot_percentage:.2f}%)") - logger.info(f"Frame {frame_num}: {matched_pixel_count} matched pixels ({matched_percentage:.2f}%)") - - # Save visualizations for this frame - save_visualization(visualization_mask, frame_num, "segmentation", viz_dir) - - # Store the masks in the weight tensors - depth_weights[i] = torch.from_numpy(depth_mask) - vis_weights[i] = torch.from_numpy(vis_mask) - edge_weights[i] = torch.from_numpy(edge_mask) - seg_weights[i] = torch.from_numpy(seg_mask) - else: - # Use video frames if available - try: - # Open the segmentation video - cap = cv2.VideoCapture(video_path) - if not cap.isOpened(): - raise ValueError(f"Could not open video file: {video_path}") - - # Process each frame using the unified color mapping - for i, json_file in enumerate(tqdm(json_files[:num_frames], desc="Processing frames")): - # Get frame number from filename - frame_num = int(os.path.basename(json_file).split("_")[-1].split(".")[0]) - - # Read the corresponding frame from the video - cap.set(cv2.CAP_PROP_POS_FRAMES, i) - ret, frame = cap.read() - - if not ret: - logger.warning(f"Warning: Could not read frame {i} from video. Using blank frame.") - frame = np.zeros((height, width, 3), dtype=np.uint8) - - # Convert frame to RGB - frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - - # Calculate total pixels - total_pixels = height * width - - # Vectorized approach for finding nearest colors - pixels = frame_rgb.reshape(-1, 3) - distances = np.sqrt(np.sum((pixels[:, np.newaxis, :] - color_array[np.newaxis, :, :]) ** 2, axis=2)) - nearest_color_indices = np.argmin(distances, axis=1) - pixel_is_robot = is_robot_array[nearest_color_indices] - pixel_is_robot_2d = pixel_is_robot.reshape(height, width) - - # Count robot and matched pixels - robot_pixel_count = np.sum(pixel_is_robot) - matched_pixel_count = pixels.shape[0] - - # Create masks based on the is_robot classification - depth_mask = np.where( - pixel_is_robot_2d, weights_dict["depth"]["foreground"], weights_dict["depth"]["background"] - ) - vis_mask = np.where( - pixel_is_robot_2d, weights_dict["vis"]["foreground"], weights_dict["vis"]["background"] - ) - edge_mask = np.where( - pixel_is_robot_2d, weights_dict["edge"]["foreground"], weights_dict["edge"]["background"] - ) - seg_mask = np.where( - pixel_is_robot_2d, weights_dict["seg"]["foreground"], weights_dict["seg"]["background"] - ) - - # Create visualization mask - visualization_mask = np.zeros((height, width), dtype=np.uint8) - visualization_mask[pixel_is_robot_2d] = 255 - - # Log statistics - robot_percentage = (robot_pixel_count / total_pixels) * 100 - matched_percentage = (matched_pixel_count / total_pixels) * 100 - logger.info(f"Frame {frame_num}: {robot_pixel_count} robot pixels ({robot_percentage:.2f}%)") - logger.info(f"Frame {frame_num}: {matched_pixel_count} matched pixels ({matched_percentage:.2f}%)") - - # Save visualizations for this frame - save_visualization(visualization_mask, frame_num, "segmentation", viz_dir) - - # Store the masks in the weight tensors - depth_weights[i] = torch.from_numpy(depth_mask) - vis_weights[i] = torch.from_numpy(vis_mask) - edge_weights[i] = torch.from_numpy(edge_mask) - seg_weights[i] = torch.from_numpy(seg_mask) - - # Close the video capture - cap.release() - except Exception as e: - logger.warning(f"Warning: Error processing video: {e}") - logger.warning("Cannot process this example without proper frame data.") - raise ValueError(f"Cannot process example without frame data: 
{e}") - - # Save weight tensors - # Convert weights to half precision (float16) to reduce file size - depth_weights_half = depth_weights.to(torch.float16) - vis_weights_half = vis_weights.to(torch.float16) - edge_weights_half = edge_weights.to(torch.float16) - seg_weights_half = seg_weights.to(torch.float16) - - # Save the half precision tensors - torch.save(depth_weights_half, os.path.join(output_dir, "depth_weights.pt")) - torch.save(vis_weights_half, os.path.join(output_dir, "vis_weights.pt")) - torch.save(edge_weights_half, os.path.join(output_dir, "edge_weights.pt")) - torch.save(seg_weights_half, os.path.join(output_dir, "seg_weights.pt")) - - logger.info(f"Saved weight matrices to {output_dir}") - logger.info(f"Weight matrix shape: {depth_weights_half.shape}, dtype: {depth_weights_half.dtype}") - logger.info(f"Saved visualizations to {viz_dir}") - - return output_dir, viz_dir - - -def process_all_examples(input_dir, output_dir, setting_name="fg_vis_edge_bg_seg", robot_keywords=None): - """Process all example directories in the provided input directory - - Args: - input_dir (str): Input directory containing example folders - output_dir (str): Output directory for weight matrices - setting_name (str, optional): Weight setting name. Defaults to 'fg_vis_edge_bg_seg'. - robot_keywords (list, optional): List of keywords to identify robot classes. Defaults to None. - """ - # Find all example directories - if not os.path.exists(input_dir): - logger.error(f"Input directory not found: {input_dir}") - return [] - - # List example directories - examples = [d for d in os.listdir(input_dir) if os.path.isdir(os.path.join(input_dir, d))] - examples = sorted(examples) - - if not examples: - logger.warning("No example directories found.") - return [] - - # Print found examples - logger.info(f"Found {len(examples)} example directories:") - for example in examples: - logger.info(f" - {example}") - - # Store processing results - results = [] - - # Process each example - for example in examples: - try: - logger.info(f"\nProcessing {example}...") - - # Process this example with custom directories - out_dir, viz_dir = process_example_with_dirs(example, input_dir, output_dir, setting_name, robot_keywords) - results.append((example, out_dir, viz_dir)) - - logger.info(f"Results for {example} saved to:") - logger.info(f" Weight matrices: {out_dir}") - logger.info(f" Visualizations: {viz_dir}") - - except Exception as e: - logger.error(f"Error processing {example}: {e}") - - logger.info("\nAll examples processed.") - return results - - -# Process a specific example with custom input and output directories -def process_example_with_dirs( - example_name, input_dir, output_dir, setting_name="fg_vis_edge_bg_seg", robot_keywords=None -): - """Process a specific example with custom input and output directories - - Args: - example_name (str): Name of the example directory - input_dir (str): Path to input directory containing example folders - output_dir (str): Path to output directory for weight matrices - setting_name (str, optional): Weight setting name. Defaults to 'fg_vis_edge_bg_seg'. - robot_keywords (list, optional): List of keywords to identify robot classes. Defaults to None. 
- """ - # Create paths for this example - example_dir = os.path.join(input_dir, example_name) - segmentation_dir = os.path.join(example_dir, "segmentation_label") - video_path = os.path.join(example_dir, "segmentation.mp4") - - # Create output directories - example_output_dir = os.path.join(output_dir, example_name) - viz_dir = os.path.join(example_output_dir, "visualizations") - - # Check if weight files already exist - depth_weights_path = os.path.join(example_output_dir, "depth_weights.pt") - if os.path.exists(depth_weights_path): - logger.info(f"Weight files already exist for {example_name}, skipping processing") - return example_output_dir, viz_dir - - # Create output directories if they don't exist - os.makedirs(example_output_dir, exist_ok=True) - os.makedirs(viz_dir, exist_ok=True) - - # Get weight settings - weights_dict = WeightSettings.get_settings(setting_name) - - # Process this example directly with paths - return process_segmentation_files( - segmentation_dir=segmentation_dir, - output_dir=example_output_dir, - viz_dir=viz_dir, - video_path=video_path, - weights_dict=weights_dict, - setting_name=setting_name, - robot_keywords=robot_keywords, - ) - - -if __name__ == "__main__": - # Parse command-line arguments - parser = argparse.ArgumentParser( - description="Process segmentation files to generate spatial-temporal weight matrices" - ) - parser.add_argument( - "--setting", - type=str, - default="fg_vis_edge_bg_seg", - choices=WeightSettings.list_settings(), - help="Weight setting to use (default: fg_vis_edge_bg_seg (setting1), fg_edge_bg_seg (setting2))", - ) - parser.add_argument( - "--input-dir", - type=str, - default="assets/robot_augmentation_example", - help="Input directory containing example folders", - ) - parser.add_argument( - "--output-dir", - type=str, - default="outputs/robot_augmentation_example", - help="Output directory for weight matrices", - ) - parser.add_argument( - "--robot-keywords", - type=str, - nargs="+", - default=["world_robot", "gripper", "robot"], - help="Keywords used to identify robot classes (default: world_robot gripper robot)", - ) - parser.add_argument( - "--log-level", - type=str, - default="INFO", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Set the logging level", - ) - args = parser.parse_args() - - # Set logging level from command line argument - logger.setLevel(getattr(logging, args.log_level)) - - # Get directories from arguments - input_dir = args.input_dir - output_dir = args.output_dir - setting_name = args.setting - robot_keywords = args.robot_keywords - - logger.info(f"Using input directory: {input_dir}") - logger.info(f"Using output directory: {output_dir}") - logger.info(f"Using weight setting: {setting_name}") - logger.info(f"Using robot keywords: {robot_keywords}") - - # Process all examples with the provided input and output directories - process_all_examples(input_dir, output_dir, setting_name, robot_keywords) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/sam2/sam2_model.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/sam2/sam2_model.py deleted file mode 100644 index 40dd41d77cdb0e575b10754db38c7155232d2b9b..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/sam2/sam2_model.py +++ /dev/null @@ -1,392 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys - -import numpy as np -import pycocotools.mask as mask_util -import torch - -from cosmos_transfer1.utils import log - -sys.path.append("cosmos_transfer1/auxiliary") - -import tempfile - -from PIL import Image -from sam2.sam2_video_predictor import SAM2VideoPredictor -from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor - -from cosmos_transfer1.auxiliary.sam2.sam2_utils import ( - capture_fps, - convert_masks_to_frames, - generate_tensor_from_images, - video_to_frames, - write_video, -) -from cosmos_transfer1.checkpoints import GROUNDING_DINO_MODEL_CHECKPOINT, SAM2_MODEL_CHECKPOINT - - -def rle_encode(mask: np.ndarray) -> dict: - """ - Encode a boolean mask (of shape (T, H, W)) using the pycocotools RLE format, - matching the format of eff_segmentation.RleMaskSAMv2 (from Yotta). - - The procedure is: - 1. Convert the mask to a numpy array in Fortran order. - 2. Reshape the array to (-1, 1) (i.e. flatten in Fortran order). - 3. Call pycocotools.mask.encode on the reshaped array. - 4. Return a dictionary with the encoded data and the original mask shape. - """ - mask = np.array(mask, order="F") - # Reshape the mask to (-1, 1) in Fortran order and encode it. - encoded = mask_util.encode(np.array(mask.reshape(-1, 1), order="F")) - return {"data": encoded, "mask_shape": mask.shape} - - -class VideoSegmentationModel: - def __init__(self, **kwargs): - """Initialize the model and load all required components.""" - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Initialize SAM2 predictor - self.sam2_predictor = SAM2VideoPredictor.from_pretrained(SAM2_MODEL_CHECKPOINT).to(self.device) - - # Initialize GroundingDINO for text-based detection - self.grounding_model_name = kwargs.get("grounding_model", GROUNDING_DINO_MODEL_CHECKPOINT) - self.processor = AutoProcessor.from_pretrained(self.grounding_model_name) - self.grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(self.grounding_model_name).to( - self.device - ) - - def get_boxes_from_text(self, image_path, text_prompt): - """Get bounding boxes (and labels) from a text prompt using GroundingDINO.""" - image = Image.open(image_path).convert("RGB") - - inputs = self.processor(images=image, text=text_prompt, return_tensors="pt").to(self.device) - - with torch.no_grad(): - outputs = self.grounding_model(**inputs) - - # Try with initial thresholds. - results = self.processor.post_process_grounded_object_detection( - outputs, - inputs.input_ids, - box_threshold=0.15, - text_threshold=0.25, - target_sizes=[image.size[::-1]], - ) - - boxes = results[0]["boxes"].cpu().numpy() - scores = results[0]["scores"].cpu().numpy() - labels = results[0].get("labels", None) - if len(boxes) == 0: - print(f"No boxes detected for prompt: '{text_prompt}'. 
Trying with lower thresholds...") - results = self.processor.post_process_grounded_object_detection( - outputs, - inputs.input_ids, - box_threshold=0.1, - text_threshold=0.1, - target_sizes=[image.size[::-1]], - ) - boxes = results[0]["boxes"].cpu().numpy() - scores = results[0]["scores"].cpu().numpy() - labels = results[0].get("labels", None) - - if len(boxes) > 0: - print(f"Found {len(boxes)} boxes with scores: {scores}") - # Sort boxes by confidence score in descending order - sorted_indices = np.argsort(scores)[::-1] - boxes = boxes[sorted_indices] - scores = scores[sorted_indices] - if labels is not None: - labels = np.array(labels)[sorted_indices] - else: - print("Still no boxes detected. Consider adjusting the prompt or using box/points mode.") - - return {"boxes": boxes, "labels": labels, "scores": scores} - - def visualize_frame(self, frame_idx, obj_ids, masks, video_dir, frame_names, visualization_data, save_dir=None): - """ - Process a single frame: load the image, apply the segmentation mask to black out the - detected object(s), and save both the masked frame and the binary mask image. - """ - # Load the frame. - frame_path = os.path.join(video_dir, frame_names[frame_idx]) - img = Image.open(frame_path).convert("RGB") - image_np = np.array(img) - - # Combine masks from the detection output. - if isinstance(masks, torch.Tensor): - mask_np = (masks[0] > 0.0).cpu().numpy().astype(bool) - combined_mask = mask_np - elif isinstance(masks, dict): - first_mask = next(iter(masks.values())) - combined_mask = np.zeros_like(first_mask, dtype=bool) - for m in masks.values(): - combined_mask |= m - else: - combined_mask = None - - if combined_mask is not None: - combined_mask = np.squeeze(combined_mask) - - # If the mask shape doesn't match the image, resize it. - if combined_mask.shape != image_np.shape[:2]: - mask_img = Image.fromarray((combined_mask.astype(np.uint8)) * 255) - mask_img = mask_img.resize((image_np.shape[1], image_np.shape[0]), resample=Image.NEAREST) - combined_mask = np.array(mask_img) > 127 - - # Black out the detected region. - image_np[combined_mask] = 0 - - mask_image = (combined_mask.astype(np.uint8)) * 255 - mask_pil = Image.fromarray(mask_image) - - if save_dir: - seg_frame_path = os.path.join(save_dir, f"frame_{frame_idx}_segmented.png") - seg_pil = Image.fromarray(image_np) - seg_pil.save(seg_frame_path) - if combined_mask is not None: - mask_save_path = os.path.join(save_dir, f"frame_{frame_idx}_mask.png") - mask_pil.save(mask_save_path) - - def sample(self, **kwargs): - """ - Main sampling function for video segmentation. - Returns a list of detections in which each detection contains a phrase and - an RLE-encoded segmentation mask (matching the output of the Grounded SAM model). - """ - video_dir = kwargs.get("video_dir", "") - mode = kwargs.get("mode", "points") - input_data = kwargs.get("input_data", None) - save_dir = kwargs.get("save_dir", None) - visualize = kwargs.get("visualize", False) - - # Get frame names (expecting frames named as numbers with .jpg/.jpeg extension). 
- frame_names = [p for p in os.listdir(video_dir) if os.path.splitext(p)[-1].lower() in [".jpg", ".jpeg"]] - frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) - - with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): - state = self.sam2_predictor.init_state(video_path=video_dir) - - ann_frame_idx = 0 - ann_obj_id = 1 - boxes = None - points = None - labels = None - box = None - - visualization_data = {"mode": mode, "points": None, "labels": None, "box": None, "boxes": None} - - if input_data is not None: - if mode == "points": - points = input_data.get("points") - labels = input_data.get("labels") - frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box( - inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels - ) - visualization_data["points"] = points - visualization_data["labels"] = labels - elif mode == "box": - box = input_data.get("box") - frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box( - inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, box=box - ) - visualization_data["box"] = box - elif mode == "prompt": - text = input_data.get("text") - first_frame_path = os.path.join(video_dir, frame_names[0]) - gd_results = self.get_boxes_from_text(first_frame_path, text) - boxes = gd_results["boxes"] - labels_out = gd_results["labels"] - scores = gd_results["scores"] - log.info(f"scores: {scores}") - if len(boxes) > 0: - legacy_mask = kwargs.get("legacy_mask", False) - if legacy_mask: - # Use only the highest confidence box for legacy mask - log.info(f"using legacy_mask: {legacy_mask}") - frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box( - inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, box=boxes[0] - ) - # Update boxes and labels after processing - boxes = boxes[:1] - if labels_out is not None: - labels_out = labels_out[:1] - else: - log.info(f"using new_mask: {legacy_mask}") - for object_id, (box, label) in enumerate(zip(boxes, labels_out)): - frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box( - inference_state=state, frame_idx=ann_frame_idx, obj_id=object_id, box=box - ) - visualization_data["boxes"] = boxes - self.grounding_labels = [str(lbl) for lbl in labels_out] if labels_out is not None else [text] - else: - print("No boxes detected. Exiting.") - return [] # Return empty list if no detections - - if visualize: - self.visualize_frame( - frame_idx=ann_frame_idx, - obj_ids=obj_ids, - masks=masks, - video_dir=video_dir, - frame_names=frame_names, - visualization_data=visualization_data, - save_dir=save_dir, - ) - - video_segments = {} # keys: frame index, values: {obj_id: mask} - for out_frame_idx, out_obj_ids, out_mask_logits in self.sam2_predictor.propagate_in_video(state): - video_segments[out_frame_idx] = { - out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids) - } - - # For propagated frames, visualization_data is not used. 
- if visualize: - propagate_visualization_data = { - "mode": mode, - "points": None, - "labels": None, - "box": None, - "boxes": None, - } - self.visualize_frame( - frame_idx=out_frame_idx, - obj_ids=out_obj_ids, - masks=video_segments[out_frame_idx], - video_dir=video_dir, - frame_names=frame_names, - visualization_data=propagate_visualization_data, - save_dir=save_dir, - ) - - # --- Post-process video_segments to produce a list of detections --- - if len(video_segments) == 0: - return [] - - first_frame_path = os.path.join(video_dir, frame_names[0]) - first_frame = np.array(Image.open(first_frame_path).convert("RGB")) - original_shape = first_frame.shape[:2] # (height, width) - - object_masks = {} # key: obj_id, value: list of 2D boolean masks - sorted_frame_indices = sorted(video_segments.keys()) - for frame_idx in sorted_frame_indices: - segments = video_segments[frame_idx] - for obj_id, mask in segments.items(): - mask = np.squeeze(mask) - if mask.ndim != 2: - print(f"Warning: Unexpected mask shape {mask.shape} for object {obj_id} in frame {frame_idx}.") - continue - - if mask.shape != original_shape: - mask_img = Image.fromarray(mask.astype(np.uint8) * 255) - mask_img = mask_img.resize((original_shape[1], original_shape[0]), resample=Image.NEAREST) - mask = np.array(mask_img) > 127 - - if obj_id not in object_masks: - object_masks[obj_id] = [] - object_masks[obj_id].append(mask) - - detections = [] - for obj_id, mask_list in object_masks.items(): - mask_stack = np.stack(mask_list, axis=0) # shape: (T, H, W) - # Use our new rle_encode (which now follows the eff_segmentation.RleMaskSAMv2 format) - rle = rle_encode(mask_stack) - if mode == "prompt" and hasattr(self, "grounding_labels"): - phrase = self.grounding_labels[0] - else: - phrase = input_data.get("text", "") - detection = {"phrase": phrase, "segmentation_mask_rle": rle} - detections.append(detection) - - return detections - - @staticmethod - def parse_points(points_str): - """Parse a string of points into a numpy array. - Supports a single point ('200,300') or multiple points separated by ';' (e.g., '200,300;100,150'). - """ - points = [] - for point in points_str.split(";"): - coords = point.split(",") - if len(coords) != 2: - continue - points.append([float(coords[0]), float(coords[1])]) - return np.array(points, dtype=np.float32) - - @staticmethod - def parse_labels(labels_str): - """Parse a comma-separated string of labels into a numpy array.""" - return np.array([int(x) for x in labels_str.split(",")], dtype=np.int32) - - @staticmethod - def parse_box(box_str): - """Parse a comma-separated string of 4 box coordinates into a numpy array.""" - return np.array([float(x) for x in box_str.split(",")], dtype=np.float32) - - def __call__( - self, - input_video, - output_video=None, - output_tensor=None, - prompt=None, - box=None, - points=None, - labels=None, - weight_scaler=None, - binarize_video=False, - legacy_mask=False, - ): - log.info( - f"Processing video: {input_video} to generate segmentation video: {output_video} segmentation tensor: {output_tensor}" - ) - assert os.path.exists(input_video) - - # Prepare input data based on the selected mode. 
- if points is not None: - mode = "points" - input_data = {"points": self.parse_points(points), "labels": self.parse_labels(labels)} - elif box is not None: - mode = "box" - input_data = {"box": self.parse_box(box)} - elif prompt is not None: - mode = "prompt" - input_data = {"text": prompt} - - with tempfile.TemporaryDirectory() as temp_input_dir: - fps = capture_fps(input_video) - video_to_frames(input_video, temp_input_dir) - with tempfile.TemporaryDirectory() as temp_output_dir: - masks = self.sample( - video_dir=temp_input_dir, - mode=mode, - input_data=input_data, - save_dir=str(temp_output_dir), - visualize=True, - legacy_mask=legacy_mask, - ) - if output_video: - os.makedirs(os.path.dirname(output_video), exist_ok=True) - frames = convert_masks_to_frames(masks) - if binarize_video: - frames = np.any(frames > 0, axis=-1).astype(np.uint8) * 255 - write_video(frames, output_video, fps) - if output_tensor: - generate_tensor_from_images( - temp_output_dir, output_tensor, fps, "mask", weight_scaler=weight_scaler - ) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/sam2/sam2_pipeline.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/sam2/sam2_pipeline.py deleted file mode 100644 index 41be7e7ea9543671b8a023d75952f8446abf08ca..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/sam2/sam2_pipeline.py +++ /dev/null @@ -1,126 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import tempfile - -import numpy as np - -from cosmos_transfer1.auxiliary.sam2.sam2_model import VideoSegmentationModel -from cosmos_transfer1.auxiliary.sam2.sam2_utils import ( - capture_fps, - generate_tensor_from_images, - generate_video_from_images, - video_to_frames, -) - - -def parse_args(): - parser = argparse.ArgumentParser(description="Video Segmentation using SAM2") - parser.add_argument("--input_video", type=str, required=True, help="Path to input video file") - parser.add_argument( - "--output_video", type=str, default="./outputs/output_video.mp4", help="Path to save the output video" - ) - parser.add_argument( - "--output_tensor", type=str, default="./outputs/output_tensor.pt", help="Path to save the output tensor" - ) - parser.add_argument( - "--mode", type=str, choices=["points", "box", "prompt"], default="points", help="Segmentation mode" - ) - parser.add_argument("--prompt", type=str, help="Text prompt for prompt mode") - parser.add_argument( - "--grounding_model_path", - type=str, - default="IDEA-Research/grounding-dino-tiny", - help="Local directory for GroundingDINO model files", - ) - parser.add_argument( - "--points", - type=str, - default="200,300", - help="Comma-separated point coordinates for points mode (e.g., '200,300' or for multiple points use ';' as a separator, e.g., '200,300;100,150').", - ) - parser.add_argument( - "--labels", - type=str, - default="1", - help="Comma-separated labels for points mode (e.g., '1' or '1,0' for multiple points).", - ) - parser.add_argument( - "--box", - type=str, - default="300,0,500,400", - help="Comma-separated box coordinates for box mode (e.g., '300,0,500,400').", - ) - # New flag to control visualization. - parser.add_argument("--visualize", action="store_true", help="If set, visualize segmentation frames (save images)") - return parser.parse_args() - - -def parse_points(points_str): - """Parse a string of points into a numpy array. - Supports a single point ('200,300') or multiple points separated by ';' (e.g., '200,300;100,150'). - """ - points = [] - for point in points_str.split(";"): - coords = point.split(",") - if len(coords) != 2: - continue - points.append([float(coords[0]), float(coords[1])]) - return np.array(points, dtype=np.float32) - - -def parse_labels(labels_str): - """Parse a comma-separated string of labels into a numpy array.""" - return np.array([int(x) for x in labels_str.split(",")], dtype=np.int32) - - -def parse_box(box_str): - """Parse a comma-separated string of 4 box coordinates into a numpy array.""" - return np.array([float(x) for x in box_str.split(",")], dtype=np.float32) - - -def main(): - args = parse_args() - - # Initialize the segmentation model. - model = VideoSegmentationModel(**vars(args)) - - # Prepare input data based on the selected mode. 
- if args.mode == "points": - input_data = {"points": parse_points(args.points), "labels": parse_labels(args.labels)} - elif args.mode == "box": - input_data = {"box": parse_box(args.box)} - elif args.mode == "prompt": - input_data = {"text": args.prompt} - - with tempfile.TemporaryDirectory() as temp_input_dir: - fps = capture_fps(args.input_video) - video_to_frames(args.input_video, temp_input_dir) - with tempfile.TemporaryDirectory() as temp_output_dir: - model.sample( - video_dir=temp_input_dir, - mode=args.mode, - input_data=input_data, - save_dir=str(temp_output_dir), - visualize=True, - ) - generate_video_from_images(temp_output_dir, args.output_video, fps) - generate_tensor_from_images(temp_output_dir, args.output_tensor, fps, "mask") - - -if __name__ == "__main__": - print("Starting video segmentation...") - main() diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/sam2/sam2_utils.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/sam2/sam2_utils.py deleted file mode 100644 index b9554fa28bef411df2a8dc03040b515554556b7d..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/sam2/sam2_utils.py +++ /dev/null @@ -1,168 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import time - -import cv2 -import imageio -import numpy as np -import pycocotools.mask -import torch -from natsort import natsorted -from PIL import Image -from torchvision import transforms - -from cosmos_transfer1.diffusion.datasets.augmentors.control_input import ( - decode_partial_rle_width1, - segmentation_color_mask, -) -from cosmos_transfer1.utils import log - - -def write_video(frames, output_path, fps=30): - """ - expects a sequence of [H, W, 3] or [H, W] frames - """ - with imageio.get_writer(output_path, fps=fps, macro_block_size=8) as writer: - for frame in frames: - if len(frame.shape) == 2: # single channel - frame = frame[:, :, None].repeat(3, axis=2) - writer.append_data(frame) - - -def capture_fps(input_video_path: str): - cap = cv2.VideoCapture(input_video_path) - fps = cap.get(cv2.CAP_PROP_FPS) - return fps - - -def video_to_frames(input_loc, output_loc): - """Function to extract frames from input video file - and save them as separate frames in an output directory. - Args: - input_loc: Input video file. - output_loc: Output directory to save the frames. - Returns: - None - """ - try: - os.mkdir(output_loc) - except OSError: - pass - # Log the time - time_start = time.time() - # Start capturing the feed - cap = cv2.VideoCapture(input_loc) - # Find the number of frames - video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - print(f"Number of frames: {video_length}") - count = 0 - print("Converting video..\n") - # Start converting the video - while cap.isOpened(): - # Extract the frame - ret, frame = cap.read() - if not ret: - continue - # Write the results back to output location. 
- cv2.imwrite(output_loc + "/%#05d.jpg" % (count + 1), frame) - count = count + 1 - # If there are no more frames left - if count > (video_length - 1): - # Log the time again - time_end = time.time() - # Release the feed - cap.release() - # Print stats - print("Done extracting frames.\n%d frames extracted" % count) - print("It took %d seconds forconversion." % (time_end - time_start)) - break - - -# Function to generate video -def convert_masks_to_frames(masks: list, num_masks_max: int = 100): - T, H, W = shape = masks[0]["segmentation_mask_rle"]["mask_shape"] - frame_start, frame_end = 0, T - num_masks = min(num_masks_max, len(masks)) - mask_ids_select = np.arange(num_masks).tolist() - - all_masks = np.zeros((num_masks, T, H, W), dtype=np.uint8) - for idx, mid in enumerate(mask_ids_select): - mask = masks[mid] - num_byte_per_mb = 1024 * 1024 - # total number of elements in uint8 (1 byte) / num_byte_per_mb - if shape[0] * shape[1] * shape[2] / num_byte_per_mb > 256: - rle = decode_partial_rle_width1( - mask["segmentation_mask_rle"]["data"], - frame_start * shape[1] * shape[2], - frame_end * shape[1] * shape[2], - ) - partial_shape = (frame_end - frame_start, shape[1], shape[2]) - rle = rle.reshape(partial_shape) * 255 - else: - rle = pycocotools.mask.decode(mask["segmentation_mask_rle"]["data"]) - rle = rle.reshape(shape) * 255 - # Select the frames that are in the video - frame_indices = np.arange(frame_start, frame_end).tolist() - rle = np.stack([rle[i] for i in frame_indices]) - all_masks[idx] = rle - del rle - - all_masks = segmentation_color_mask(all_masks) # NTHW -> 3THW - all_masks = all_masks.transpose(1, 2, 3, 0) - return all_masks - - -def generate_video_from_images(masks: list, output_file_path: str, fps, num_masks_max: int = 100): - all_masks = convert_masks_to_frames(masks, num_masks_max) - write_video(all_masks, output_file_path, fps) - print("Video generated successfully!") - - -def generate_tensor_from_images( - image_path_str: str, output_file_path: str, fps, search_pattern: str = None, weight_scaler: float = None -): - images = list() - image_path = os.path.abspath(image_path_str) - if search_pattern is None: - images = [img for img in natsorted(os.listdir(image_path))] - else: - for img in natsorted(os.listdir(image_path)): - if img.__contains__(search_pattern): - images.append(img) - - transform = transforms.ToTensor() - image_tensors = list() - for image in images: - img_tensor = transform(Image.open(os.path.join(image_path, image))) - image_tensors.append(img_tensor.squeeze(0)) - - tensor = torch.stack(image_tensors) # [T, H, W], binary values, float - - if weight_scaler is not None: - log.info(f"scaling the tensor by the specified scale: {weight_scaler}") - tensor = tensor * weight_scaler - - log.info(f"saving tensor shape: {tensor.shape} to {output_file_path}") - torch.save(tensor, output_file_path) - - -if __name__ == "__main__": - input_loc = "cosmos_transfer1/models/sam2/assets/input_video.mp4" - output_loc = os.path.abspath(tempfile.TemporaryDirectory().name) - print(f"output_loc --- {output_loc}") - video_to_frames(input_loc, output_loc) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/__init__.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/__init__.py deleted file mode 100644 index 3159bfe65645499015bd92609b99d476d69544e9..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright 
(c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/image_cli.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/image_cli.py deleted file mode 100644 index e4dd00fe8f41f67e89ef80c6f272235ac32e32f5..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/image_cli.py +++ /dev/null @@ -1,188 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A CLI to run ImageTokenizer on plain images based on torch.jit. 
- -Usage: - python3 -m cosmos_transfer1.auxiliary.tokenizer.inference.image_cli \ - --image_pattern 'path/to/input/folder/*.jpg' \ - --output_dir ./reconstructions \ - --checkpoint_enc ./checkpoints//encoder.jit \ - --checkpoint_dec ./checkpoints//decoder.jit - - Optionally, you can run the model in pure PyTorch mode: - python3 -m cosmos_transfer1.auxiliary.tokenizer.inference.image_cli \ - --image_pattern 'path/to/input/folder/*.jpg' \ - --mode torch \ - --tokenizer_type CI \ - --spatial_compression 8 \ - --checkpoint_enc ./checkpoints//encoder.jit \ - --checkpoint_dec ./checkpoints//decoder.jit -""" - -import os -import sys -from argparse import ArgumentParser, Namespace -from typing import Any - -import numpy as np -from loguru import logger as logging - -from cosmos_transfer1.auxiliary.tokenizer.inference.image_lib import ImageTokenizer -from cosmos_transfer1.auxiliary.tokenizer.inference.utils import ( - get_filepaths, - get_output_filepath, - read_image, - resize_image, - write_image, -) -from cosmos_transfer1.auxiliary.tokenizer.networks import TokenizerConfigs - - -def _parse_args() -> tuple[Namespace, dict[str, Any]]: - parser = ArgumentParser(description="A CLI for running ImageTokenizer on plain images.") - parser.add_argument( - "--image_pattern", - type=str, - default="path/to/images/*.jpg", - help="Glob pattern.", - ) - parser.add_argument( - "--checkpoint", - type=str, - default=None, - help="JIT full Autoencoder model filepath.", - ) - parser.add_argument( - "--checkpoint_enc", - type=str, - default=None, - help="JIT Encoder model filepath.", - ) - parser.add_argument( - "--checkpoint_dec", - type=str, - default=None, - help="JIT Decoder model filepath.", - ) - parser.add_argument( - "--tokenizer_type", - type=str, - choices=["CI", "DI"], - help="Specifies the tokenizer type.", - ) - parser.add_argument( - "--spatial_compression", - type=int, - choices=[8, 16], - default=8, - help="The spatial compression factor.", - ) - parser.add_argument( - "--mode", - type=str, - choices=["torch", "jit"], - default="jit", - help="Specify the backend: native 'torch' or 'jit' (default: 'jit')", - ) - parser.add_argument( - "--short_size", - type=int, - default=None, - help="The size to resample inputs. None, by default.", - ) - parser.add_argument( - "--dtype", - type=str, - default="bfloat16", - help="Sets the precision. Default bfloat16.", - ) - parser.add_argument( - "--device", - type=str, - default="cuda", - help="Device for invoking the model.", - ) - parser.add_argument("--output_dir", type=str, default=None, help="Output directory.") - parser.add_argument( - "--save_input", - action="store_true", - help="If on, the input image will be be outputed too.", - ) - args = parser.parse_args() - return args - - -logging.info("Initializes args ...") -args = _parse_args() -if args.mode == "torch" and args.tokenizer_type not in ["CI", "DI"]: - logging.error("'torch' backend requires the tokenizer_type of 'CI' or 'DI'.") - sys.exit(1) - - -def _run_eval() -> None: - """Invokes the evaluation pipeline.""" - - if args.checkpoint_enc is None and args.checkpoint_dec is None and args.checkpoint is None: - logging.warning("Aborting. Both encoder or decoder JIT required. 
Or provide the full autoencoder JIT model.") - return - - if args.mode == "torch": - tokenizer_config = TokenizerConfigs[args.tokenizer_type].value - tokenizer_config.update(dict(spatial_compression=args.spatial_compression)) - else: - tokenizer_config = None - - logging.info( - f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..." - ) - autoencoder = ImageTokenizer( - checkpoint=args.checkpoint, - checkpoint_enc=args.checkpoint_enc, - checkpoint_dec=args.checkpoint_dec, - tokenizer_config=tokenizer_config, - device=args.device, - dtype=args.dtype, - ) - - filepaths = get_filepaths(args.image_pattern) - logging.info(f"Found {len(filepaths)} images from {args.image_pattern}.") - - for filepath in filepaths: - logging.info(f"Reading image {filepath} ...") - image = read_image(filepath) - image = resize_image(image, short_size=args.short_size) - batch_image = np.expand_dims(image, axis=0) - - logging.info("Invoking the autoencoder model in ... ") - output_image = autoencoder(batch_image)[0] - - output_filepath = get_output_filepath(filepath, output_dir=args.output_dir) - logging.info(f"Outputing {output_filepath} ...") - write_image(output_filepath, output_image) - - if args.save_input: - ext = os.path.splitext(output_filepath)[-1] - input_filepath = output_filepath.replace(ext, "_input" + ext) - write_image(input_filepath, image) - - -@logging.catch(reraise=True) -def main() -> None: - _run_eval() - - -if __name__ == "__main__": - main() diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/image_lib.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/image_lib.py deleted file mode 100644 index 9e24049acfee903a0d7cbc4d0bae57c1265ff0ae..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/image_lib.py +++ /dev/null @@ -1,124 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""A library for image tokenizers inference.""" - -from typing import Any - -import numpy as np -import torch - -from cosmos_transfer1.auxiliary.tokenizer.inference.utils import ( - load_decoder_model, - load_encoder_model, - load_model, - numpy2tensor, - pad_image_batch, - tensor2numpy, - unpad_image_batch, -) - - -class ImageTokenizer(torch.nn.Module): - def __init__( - self, - checkpoint: str = None, - checkpoint_enc: str = None, - checkpoint_dec: str = None, - tokenizer_config: dict[str, Any] = None, - device: str = "cuda", - dtype: str = "bfloat16", - ) -> None: - super().__init__() - self._device = device - self._dtype = getattr(torch, dtype) - self._full_model = ( - load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None - ) - self._enc_model = ( - load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype) - if checkpoint_enc is not None - else None - ) - self._dec_model = ( - load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype) - if checkpoint_dec is not None - else None - ) - - @torch.no_grad() - def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor: - """Reconstrcuts a batch of image tensors after embedding into a latent. - - Args: - input_tensor: The input image Bx3xHxW layout, range [-1..1]. - Returns: - The reconstructed tensor, layout Bx3xHxW, range [-1..1]. - """ - if self._full_model is not None: - output_tensor = self._full_model(input_tensor) - output_tensor = output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor - else: - output_latent = self.encode(input_tensor)[0] - output_tensor = self.decode(output_latent) - return output_tensor - - @torch.no_grad() - def decode(self, input_latent: torch.Tensor) -> torch.Tensor: - """Decodes an image from a provided latent embedding. - - Args: - input_latent: The continuous latent Bx16xhxw for CI, - or the discrete indices Bxhxw for DI. - Returns: - The output tensor in Bx3xHxW, range [-1..1]. - """ - return self._dec_model(input_latent) - - @torch.no_grad() - def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]: - """Encodes an image into a latent embedding or code. - - Args: - input_tensor: The input tensor Bx3xHxW layout, range [-1..1]. - Returns: - For continuous image (CI) tokenizer, the tuple contains: - - The latent embedding, Bx16x(h)x(w), where the compression - rate is (H/h x W/w), and channel dimension of 16. - For discrete image (DI) tokenizer, the tuple contains: - - The indices, Bx(h)x(w), from a codebook of size 64K, which - corresponds to FSQ levels of (8,8,8,5,5,5). - - The discrete code, Bx6x(h)x(w), where the compression rate is - again (H/h x W/w), and channel dimension of 6. - """ - output_latent = self._enc_model(input_tensor) - if isinstance(output_latent, torch.Tensor): - return output_latent - return output_latent[:-1] - - @torch.no_grad() - def forward(self, image: np.ndarray) -> np.ndarray: - """Reconstructs an image using a pre-trained tokenizer. - - Args: - image: The input image BxHxWxC layout, range [0..255]. - Returns: - The reconstructed image in range [0..255], layout BxHxWxC. 
- """ - padded_input_image, crop_region = pad_image_batch(image) - input_tensor = numpy2tensor(padded_input_image, dtype=self._dtype, device=self._device) - output_tensor = self.autoencode(input_tensor) - padded_output_image = tensor2numpy(output_tensor) - return unpad_image_batch(padded_output_image, crop_region) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/utils.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/utils.py deleted file mode 100644 index 9fe4e2e86a10011aea448dc2c02aa7dc84cc141d..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/utils.py +++ /dev/null @@ -1,402 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility functions for the inference libraries.""" - -import os -from glob import glob -from typing import Any - -import mediapy as media -import numpy as np -import torch - -from cosmos_transfer1.auxiliary.tokenizer.networks import TokenizerModels - -_DTYPE, _DEVICE = torch.bfloat16, "cuda" -_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max) -_SPATIAL_ALIGN = 16 -_TEMPORAL_ALIGN = 8 - - -def load_model( - jit_filepath: str = None, - tokenizer_config: dict[str, Any] = None, - device: str = "cuda", -) -> torch.nn.Module | torch.jit.ScriptModule: - """Loads a torch.nn.Module from a filepath. - - Args: - jit_filepath: The filepath to the JIT-compiled model. - device: The device to load the model onto, default=cuda. - Returns: - The JIT compiled model loaded to device and on eval mode. - """ - if tokenizer_config is None: - return load_jit_model(jit_filepath, device) - full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) - full_model.load_state_dict(ckpts.state_dict(), strict=False) - return full_model.eval().to(device) - - -def load_encoder_model( - jit_filepath: str = None, - tokenizer_config: dict[str, Any] = None, - device: str = "cuda", -) -> torch.nn.Module | torch.jit.ScriptModule: - """Loads a torch.nn.Module from a filepath. - - Args: - jit_filepath: The filepath to the JIT-compiled model. - device: The device to load the model onto, default=cuda. - Returns: - The JIT compiled model loaded to device and on eval mode. - """ - if tokenizer_config is None: - return load_jit_model(jit_filepath, device) - full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) - encoder_model = full_model.encoder_jit() - encoder_model.load_state_dict(ckpts.state_dict(), strict=False) - return encoder_model.eval().to(device) - - -def load_decoder_model( - jit_filepath: str = None, - tokenizer_config: dict[str, Any] = None, - device: str = "cuda", -) -> torch.nn.Module | torch.jit.ScriptModule: - """Loads a torch.nn.Module from a filepath. - - Args: - jit_filepath: The filepath to the JIT-compiled model. - device: The device to load the model onto, default=cuda. 
- Returns: - The JIT compiled model loaded to device and on eval mode. - """ - if tokenizer_config is None: - return load_jit_model(jit_filepath, device) - full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) - decoder_model = full_model.decoder_jit() - decoder_model.load_state_dict(ckpts.state_dict(), strict=False) - return decoder_model.eval().to(device) - - -def _load_pytorch_model( - jit_filepath: str = None, tokenizer_config: str = None, device: str = "cuda" -) -> torch.nn.Module: - """Loads a torch.nn.Module from a filepath. - - Args: - jit_filepath: The filepath to the JIT-compiled model. - device: The device to load the model onto, default=cuda. - Returns: - The JIT compiled model loaded to device and on eval mode. - """ - tokenizer_name = tokenizer_config["name"] - model = TokenizerModels[tokenizer_name].value(**tokenizer_config) - ckpts = torch.jit.load(jit_filepath, map_location=device) - return model, ckpts - - -def load_jit_model(jit_filepath: str = None, device: str = "cuda") -> torch.jit.ScriptModule: - """Loads a torch.jit.ScriptModule from a filepath. - - Args: - jit_filepath: The filepath to the JIT-compiled model. - device: The device to load the model onto, default=cuda. - Returns: - The JIT compiled model loaded to device and on eval mode. - """ - model = torch.jit.load(jit_filepath, map_location=device) - return model.eval().to(device) - - -def save_jit_model( - model: torch.jit.ScriptModule | torch.jit.RecursiveScriptModule = None, - jit_filepath: str = None, -) -> None: - """Saves a torch.jit.ScriptModule or torch.jit.RecursiveScriptModule to file. - - Args: - model: JIT compiled model loaded onto `config.checkpoint.jit.device`. - jit_filepath: The filepath to the JIT-compiled model. - """ - torch.jit.save(model, jit_filepath) - - -def get_filepaths(input_pattern) -> list[str]: - """Returns a list of filepaths from a pattern.""" - filepaths = sorted(glob(str(input_pattern))) - return list(set(filepaths)) - - -def get_output_filepath(filepath: str, output_dir: str = None) -> str: - """Returns the output filepath for the given input filepath.""" - output_dir = output_dir or f"{os.path.dirname(filepath)}/reconstructions" - output_filepath = f"{output_dir}/{os.path.basename(filepath)}" - os.makedirs(output_dir, exist_ok=True) - return output_filepath - - -def read_image(filepath: str) -> np.ndarray: - """Reads an image from a filepath. - - Args: - filepath: The filepath to the image. - - Returns: - The image as a numpy array, layout HxWxC, range [0..255], uint8 dtype. - """ - image = media.read_image(filepath) - # convert the grey scale image to RGB - # since our tokenizers always assume 3-channel RGB image - if image.ndim == 2: - image = np.stack([image] * 3, axis=-1) - # convert RGBA to RGB - if image.shape[-1] == 4: - image = image[..., :3] - return image - - -def read_video(filepath: str) -> np.ndarray: - """Reads a video from a filepath. - - Args: - filepath: The filepath to the video. - Returns: - The video as a numpy array, layout TxHxWxC, range [0..255], uint8 dtype. - """ - video = media.read_video(filepath) - # convert the grey scale frame to RGB - # since our tokenizers always assume 3-channel video - if video.ndim == 3: - video = np.stack([video] * 3, axis=-1) - # convert RGBA to RGB - if video.shape[-1] == 4: - video = video[..., :3] - return video - - -def resize_image(image: np.ndarray, short_size: int = None) -> np.ndarray: - """Resizes an image to have the short side of `short_size`. 
- - Args: - image: The image to resize, layout HxWxC, of any range. - short_size: The size of the short side. - Returns: - The resized image. - """ - if short_size is None: - return image - height, width = image.shape[-3:-1] - if height <= width: - height_new, width_new = short_size, int(width * short_size / height + 0.5) - width_new = width_new if width_new % 2 == 0 else width_new + 1 - else: - height_new, width_new = ( - int(height * short_size / width + 0.5), - short_size, - ) - height_new = height_new if height_new % 2 == 0 else height_new + 1 - return media.resize_image(image, shape=(height_new, width_new)) - - -def resize_video(video: np.ndarray, short_size: int = None) -> np.ndarray: - """Resizes a video to have the short side of `short_size`. - - Args: - video: The video to resize, layout TxHxWxC, of any range. - short_size: The size of the short side. - Returns: - The resized video. - """ - if short_size is None: - return video - height, width = video.shape[-3:-1] - if height <= width: - height_new, width_new = short_size, int(width * short_size / height + 0.5) - width_new = width_new if width_new % 2 == 0 else width_new + 1 - else: - height_new, width_new = ( - int(height * short_size / width + 0.5), - short_size, - ) - height_new = height_new if height_new % 2 == 0 else height_new + 1 - return media.resize_video(video, shape=(height_new, width_new)) - - -def write_image(filepath: str, image: np.ndarray): - """Writes an image to a filepath.""" - return media.write_image(filepath, image) - - -def write_video(filepath: str, video: np.ndarray, fps: int = 24) -> None: - """Writes a video to a filepath.""" - return media.write_video(filepath, video, fps=fps) - - -def numpy2tensor( - input_image: np.ndarray, - dtype: torch.dtype = _DTYPE, - device: str = _DEVICE, - range_min: int = -1, -) -> torch.Tensor: - """Converts image(dtype=np.uint8) to `dtype` in range [0..255]. - - Args: - input_image: A batch of images in range [0..255], BxHxWx3 layout. - Returns: - A torch.Tensor of layout Bx3xHxW in range [-1..1], dtype. - """ - ndim = input_image.ndim - indices = list(range(1, ndim))[-1:] + list(range(1, ndim))[:-1] - image = input_image.transpose((0,) + tuple(indices)) / _UINT8_MAX_F - if range_min == -1: - image = 2.0 * image - 1.0 - return torch.from_numpy(image).to(dtype).to(device) - - -def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray: - """Converts tensor in [-1,1] to image(dtype=np.uint8) in range [0..255]. - - Args: - input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1]. - Returns: - A numpy image of layout BxHxWx3, range [0..255], uint8 dtype. - """ - if range_min == -1: - input_tensor = (input_tensor.float() + 1.0) / 2.0 - ndim = input_tensor.ndim - output_image = input_tensor.clamp(0, 1).cpu().numpy() - output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,)) - return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8) - - -def pad_image_batch(batch: np.ndarray, spatial_align: int = _SPATIAL_ALIGN) -> tuple[np.ndarray, list[int]]: - """Pads a batch of images to be divisible by `spatial_align`. - - Args: - batch: The batch of images to pad, layout BxHxWx3, in any range. - align: The alignment to pad to. - Returns: - The padded batch and the crop region. 
- """ - height, width = batch.shape[1:3] - align = spatial_align - height_to_pad = (align - height % align) if height % align != 0 else 0 - width_to_pad = (align - width % align) if width % align != 0 else 0 - - crop_region = [ - height_to_pad >> 1, - width_to_pad >> 1, - height + (height_to_pad >> 1), - width + (width_to_pad >> 1), - ] - batch = np.pad( - batch, - ( - (0, 0), - (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)), - (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), - (0, 0), - ), - mode="constant", - ) - return batch, crop_region - - -def pad_video_batch( - batch: np.ndarray, - temporal_align: int = _TEMPORAL_ALIGN, - spatial_align: int = _SPATIAL_ALIGN, -) -> tuple[np.ndarray, list[int]]: - """Pads a batch of videos to be divisible by `temporal_align` or `spatial_align`. - - Zero pad spatially. Reflection pad temporally to handle causality better. - Args: - batch: The batch of videos to pad., layout BxFxHxWx3, in any range. - align: The alignment to pad to. - Returns: - The padded batch and the crop region. - """ - num_frames, height, width = batch.shape[-4:-1] - align = spatial_align - height_to_pad = (align - height % align) if height % align != 0 else 0 - width_to_pad = (align - width % align) if width % align != 0 else 0 - - align = temporal_align - frames_to_pad = (align - (num_frames - 1) % align) if (num_frames - 1) % align != 0 else 0 - - crop_region = [ - frames_to_pad >> 1, - height_to_pad >> 1, - width_to_pad >> 1, - num_frames + (frames_to_pad >> 1), - height + (height_to_pad >> 1), - width + (width_to_pad >> 1), - ] - batch = np.pad( - batch, - ( - (0, 0), - (0, 0), - (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)), - (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), - (0, 0), - ), - mode="constant", - ) - batch = np.pad( - batch, - ( - (0, 0), - (frames_to_pad >> 1, frames_to_pad - (frames_to_pad >> 1)), - (0, 0), - (0, 0), - (0, 0), - ), - mode="edge", - ) - return batch, crop_region - - -def unpad_video_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray: - """Unpads video with `crop_region`. - - Args: - batch: A batch of numpy videos, layout BxFxHxWxC. - crop_region: [f1,y1,x1,f2,y2,x2] first, top, left, last, bot, right crop indices. - - Returns: - np.ndarray: Cropped numpy video, layout BxFxHxWxC. - """ - assert len(crop_region) == 6, "crop_region should be len of 6." - f1, y1, x1, f2, y2, x2 = crop_region - return batch[..., f1:f2, y1:y2, x1:x2, :] - - -def unpad_image_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray: - """Unpads image with `crop_region`. - - Args: - batch: A batch of numpy images, layout BxHxWxC. - crop_region: [y1,x1,y2,x2] top, left, bot, right crop indices. - - Returns: - np.ndarray: Cropped numpy image, layout BxHxWxC. - """ - assert len(crop_region) == 4, "crop_region should be len of 4." - y1, x1, y2, x2 = crop_region - return batch[..., y1:y2, x1:x2, :] diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/video_cli.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/video_cli.py deleted file mode 100644 index d35b5aa7c573c64e45352f95d988787902cfd880..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/video_cli.py +++ /dev/null @@ -1,210 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A CLI to run CausalVideoTokenizer on plain videos based on torch.jit. - -Usage: - python3 -m cosmos_transfer1.auxiliary.tokenizer.inference.video_cli \ - --video_pattern 'path/to/video/samples/*.mp4' \ - --output_dir ./reconstructions \ - --checkpoint_enc ./checkpoints//encoder.jit \ - --checkpoint_dec ./checkpoints//decoder.jit - - Optionally, you can run the model in pure PyTorch mode: - python3 -m cosmos_transfer1.auxiliary.tokenizer.inference.video_cli \ - --video_pattern 'path/to/video/samples/*.mp4' \ - --mode=torch \ - --tokenizer_type=CV \ - --temporal_compression=4 \ - --spatial_compression=8 \ - --checkpoint_enc ./checkpoints//encoder.jit \ - --checkpoint_dec ./checkpoints//decoder.jit -""" - -import os -import sys -from argparse import ArgumentParser, Namespace -from typing import Any - -import numpy as np -from loguru import logger as logging - -from cosmos_transfer1.auxiliary.tokenizer.inference.utils import ( - get_filepaths, - get_output_filepath, - read_video, - resize_video, - write_video, -) -from cosmos_transfer1.auxiliary.tokenizer.inference.video_lib import CausalVideoTokenizer -from cosmos_transfer1.auxiliary.tokenizer.networks import TokenizerConfigs - - -def _parse_args() -> tuple[Namespace, dict[str, Any]]: - parser = ArgumentParser(description="A CLI for CausalVideoTokenizer.") - parser.add_argument( - "--video_pattern", - type=str, - default="path/to/videos/*.mp4", - help="Glob pattern.", - ) - parser.add_argument( - "--checkpoint", - type=str, - default=None, - help="JIT full Autoencoder model filepath.", - ) - parser.add_argument( - "--checkpoint_enc", - type=str, - default=None, - help="JIT Encoder model filepath.", - ) - parser.add_argument( - "--checkpoint_dec", - type=str, - default=None, - help="JIT Decoder model filepath.", - ) - parser.add_argument( - "--tokenizer_type", - type=str, - choices=["CV", "DV"], - help="Specifies the tokenizer type.", - ) - parser.add_argument( - "--spatial_compression", - type=int, - choices=[8, 16], - default=8, - help="The spatial compression factor.", - ) - parser.add_argument( - "--temporal_compression", - type=int, - choices=[4, 8], - default=4, - help="The temporal compression factor.", - ) - parser.add_argument( - "--mode", - type=str, - choices=["torch", "jit"], - default="jit", - help="Specify the backend: native 'torch' or 'jit' (default: 'jit')", - ) - parser.add_argument( - "--short_size", - type=int, - default=None, - help="The size to resample inputs. 
None, by default.", - ) - parser.add_argument( - "--temporal_window", - type=int, - default=17, - help="The temporal window to operate at a time.", - ) - parser.add_argument( - "--dtype", - type=str, - default="bfloat16", - help="Sets the precision, default bfloat16.", - ) - parser.add_argument( - "--device", - type=str, - default="cuda", - help="Device for invoking the model.", - ) - parser.add_argument("--output_dir", type=str, default=None, help="Output directory.") - parser.add_argument( - "--output_fps", - type=float, - default=24.0, - help="Output frames-per-second (FPS).", - ) - parser.add_argument( - "--save_input", - action="store_true", - help="If on, the input video will be be outputted too.", - ) - - args = parser.parse_args() - return args - - -logging.info("Initializes args ...") -args = _parse_args() -if args.mode == "torch" and args.tokenizer_type not in ["CV", "DV"]: - logging.error("'torch' backend requires the tokenizer_type of 'CV' or 'DV'.") - sys.exit(1) - - -def _run_eval() -> None: - """Invokes JIT-compiled CausalVideoTokenizer on an input video.""" - - if args.checkpoint_enc is None and args.checkpoint_dec is None and args.checkpoint is None: - logging.warning("Aborting. Both encoder or decoder JIT required. Or provide the full autoencoder JIT model.") - return - - if args.mode == "torch": - tokenizer_config = TokenizerConfigs[args.tokenizer_type].value - tokenizer_config.update(dict(spatial_compression=args.spatial_compression)) - tokenizer_config.update(dict(temporal_compression=args.temporal_compression)) - else: - tokenizer_config = None - - logging.info( - f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..." - ) - autoencoder = CausalVideoTokenizer( - checkpoint=args.checkpoint, - checkpoint_enc=args.checkpoint_enc, - checkpoint_dec=args.checkpoint_dec, - tokenizer_config=tokenizer_config, - device=args.device, - dtype=args.dtype, - ) - - logging.info(f"Looking for files matching video_pattern={args.video_pattern} ...") - filepaths = get_filepaths(args.video_pattern) - logging.info(f"Found {len(filepaths)} videos from {args.video_pattern}.") - - for filepath in filepaths: - logging.info(f"Reading video {filepath} ...") - video = read_video(filepath) - video = resize_video(video, short_size=args.short_size) - - logging.info("Invoking the autoencoder model in ... ") - batch_video = video[np.newaxis, ...] - output_video = autoencoder(batch_video, temporal_window=args.temporal_window)[0] - logging.info("Constructing output filepath ...") - output_filepath = get_output_filepath(filepath, output_dir=args.output_dir) - logging.info(f"Outputing {output_filepath} ...") - write_video(output_filepath, output_video, fps=args.output_fps) - if args.save_input: - ext = os.path.splitext(output_filepath)[-1] - input_filepath = output_filepath.replace(ext, "_input" + ext) - write_video(input_filepath, video, fps=args.output_fps) - - -@logging.catch(reraise=True) -def main() -> None: - _run_eval() - - -if __name__ == "__main__": - main() diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/video_lib.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/video_lib.py deleted file mode 100644 index b8d9bb9642aeac2074c3e8e19d54b19155367305..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/inference/video_lib.py +++ /dev/null @@ -1,146 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A library for Causal Video Tokenizer inference.""" - -from typing import Any - -import numpy as np -import torch -from tqdm import tqdm - -from cosmos_transfer1.auxiliary.tokenizer.inference.utils import ( - load_decoder_model, - load_encoder_model, - load_model, - numpy2tensor, - pad_video_batch, - tensor2numpy, - unpad_video_batch, -) - - -class CausalVideoTokenizer(torch.nn.Module): - def __init__( - self, - checkpoint: str = None, - checkpoint_enc: str = None, - checkpoint_dec: str = None, - tokenizer_config: dict[str, Any] = None, - device: str = "cuda", - dtype: str = "bfloat16", - ) -> None: - super().__init__() - self._device = device - self._dtype = getattr(torch, dtype) - self._full_model = ( - load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None - ) - self._enc_model = ( - load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype) - if checkpoint_enc is not None - else None - ) - self._dec_model = ( - load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype) - if checkpoint_dec is not None - else None - ) - - @torch.no_grad() - def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor: - """Reconstrcuts a batch of video tensors after embedding into a latent. - - Args: - video: The input video Bx3xTxHxW layout, range [-1..1]. - Returns: - The reconstructed video, layout Bx3xTxHxW, range [-1..1]. - """ - if self._full_model is not None: - output_tensor = self._full_model(input_tensor) - output_tensor = output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor - else: - output_latent = self.encode(input_tensor)[0] - output_tensor = self.decode(output_latent) - return output_tensor - - @torch.no_grad() - def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]: - """Encodes a numpy video into a CausalVideo latent or code. - - Args: - input_tensor: The input tensor Bx3xTxHxW layout, range [-1..1]. - Returns: - For causal continuous video (CV) tokenizer, the tuple contains: - - The latent embedding, Bx16x(t)x(h)x(w), where the compression - rate is (T/t x H/h x W/w), and channel dimension of 16. - For causal discrete video (DV) tokenizer, the tuple contains: - 1) The indices, Bx(t)x(h)x(w), from a codebook of size 64K, which - is formed by FSQ levels of (8,8,8,5,5,5). - 2) The discrete code, Bx6x(t)x(h)x(w), where the compression rate - is again (T/t x H/h x W/w), and channel dimension of 6. - """ - assert input_tensor.ndim == 5, "input video should be of 5D." - - output_latent = self._enc_model(input_tensor) - if isinstance(output_latent, torch.Tensor): - return output_latent - return output_latent[:-1] - - @torch.no_grad() - def decode(self, input_latent: torch.Tensor) -> torch.Tensor: - """Encodes a numpy video into a CausalVideo latent. - - Args: - input_latent: The continuous latent Bx16xtxhxw for CV, - or the discrete indices Bxtxhxw for DV. 
- Returns: - The reconstructed tensor, layout [B,3,1+(T-1)*8,H*16,W*16] in range [-1..1]. - """ - assert input_latent.ndim >= 4, "input latent should be of 5D for continuous and 4D for discrete." - return self._dec_model(input_latent) - - def forward( - self, - video: np.ndarray, - temporal_window: int = 17, - ) -> np.ndarray: - """Reconstructs video using a pre-trained CausalTokenizer autoencoder. - Given a video of arbitrary length, the forward invokes the CausalVideoTokenizer - in a sliding manner with a `temporal_window` size. - - Args: - video: The input video BxTxHxWx3 layout, range [0..255]. - temporal_window: The length of the temporal window to process, default=25. - Returns: - The reconstructed video in range [0..255], layout BxTxHxWx3. - """ - assert video.ndim == 5, "input video should be of 5D." - num_frames = video.shape[1] # can be of any length. - output_video_list = [] - for idx in tqdm(range(0, (num_frames - 1) // temporal_window + 1)): - # Input video for the current window. - start, end = idx * temporal_window, (idx + 1) * temporal_window - input_video = video[:, start:end, ...] - - # Spatio-temporally pad input_video so it's evenly divisible. - padded_input_video, crop_region = pad_video_batch(input_video) - input_tensor = numpy2tensor(padded_input_video, dtype=self._dtype, device=self._device) - output_tensor = self.autoencode(input_tensor) - padded_output_video = tensor2numpy(output_tensor) - output_video = unpad_video_batch(padded_output_video, crop_region) - - output_video_list.append(output_video) - return np.concatenate(output_video_list, axis=1) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/__init__.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/__init__.py deleted file mode 100644 index 0a1eddcd176ae2a4911ee2e8d6150012cc9a5e73..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/__init__.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
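# The sliding-window loop in CausalVideoTokenizer.forward above slices the input
# clip into chunks of `temporal_window` frames and reconstructs each chunk
# independently. A minimal standalone sketch of that bookkeeping (editorial
# illustration only; the helper name and the 121-frame clip are hypothetical,
# not part of the codebase):
def window_bounds(num_frames: int, temporal_window: int = 17) -> list[tuple[int, int]]:
    """Return the effective (start, end) frame slices iterated by forward()."""
    num_windows = (num_frames - 1) // temporal_window + 1
    return [
        (i * temporal_window, min((i + 1) * temporal_window, num_frames))
        for i in range(num_windows)
    ]

# A 121-frame clip with the default window of 17 is processed in 8 chunks,
# the last one holding only frames 119 and 120.
assert len(window_bounds(121)) == 8
assert window_bounds(121)[-1] == (119, 121)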
- -from enum import Enum - -from cosmos_transfer1.auxiliary.tokenizer.modules.distributions import GaussianDistribution, IdentityDistribution -from cosmos_transfer1.auxiliary.tokenizer.modules.layers2d import Decoder, Encoder -from cosmos_transfer1.auxiliary.tokenizer.modules.layers3d import ( - DecoderBase, - DecoderFactorized, - EncoderBase, - EncoderFactorized, -) -from cosmos_transfer1.auxiliary.tokenizer.modules.quantizers import ( - FSQuantizer, - LFQuantizer, - ResidualFSQuantizer, - VectorQuantizer, -) - - -class EncoderType(Enum): - Default = Encoder - - -class DecoderType(Enum): - Default = Decoder - - -class Encoder3DType(Enum): - BASE = EncoderBase - FACTORIZED = EncoderFactorized - - -class Decoder3DType(Enum): - BASE = DecoderBase - FACTORIZED = DecoderFactorized - - -class ContinuousFormulation(Enum): - VAE = GaussianDistribution - AE = IdentityDistribution - - -class DiscreteQuantizer(Enum): - VQ = VectorQuantizer - LFQ = LFQuantizer - FSQ = FSQuantizer - RESFSQ = ResidualFSQuantizer diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/distributions.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/distributions.py deleted file mode 100644 index 2347f7453611d9fea87d0f530bd8e54f02c3f39e..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/distributions.py +++ /dev/null @@ -1,42 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The distribution modes to use for continuous image tokenizers.""" - -import torch - - -class IdentityDistribution(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, parameters): - return parameters, (torch.tensor([0.0]), torch.tensor([0.0])) - - -class GaussianDistribution(torch.nn.Module): - def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0): - super().__init__() - self.min_logvar = min_logvar - self.max_logvar = max_logvar - - def sample(self, mean, logvar): - std = torch.exp(0.5 * logvar) - return mean + std * torch.randn_like(mean) - - def forward(self, parameters): - mean, logvar = torch.chunk(parameters, 2, dim=1) - logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar) - return self.sample(mean, logvar), (mean, logvar) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/layers2d.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/layers2d.py deleted file mode 100644 index 8d6a07117e1b12a70e3830761ba0b2efb9c816ba..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/layers2d.py +++ /dev/null @@ -1,329 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
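# GaussianDistribution.forward above implements the standard reparameterization
# trick: the encoder emits 2*C channels that are split into (mean, logvar),
# logvar is clamped to [min_logvar, max_logvar], and a sample is drawn as
# mean + exp(0.5 * logvar) * eps. A minimal standalone sketch (editorial
# illustration with made-up shapes, not part of the original file):
import torch

params = torch.randn(2, 8, 4, 4)                 # B=2, 2*C=8 channels
mean, logvar = torch.chunk(params, 2, dim=1)     # each (2, 4, 4, 4)
logvar = torch.clamp(logvar, -30.0, 20.0)
sample = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)
assert sample.shape == mean.shape == (2, 4, 4, 4)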
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The model definition for Continuous 2D layers - -Adapted from: https://github.com/CompVis/stable-diffusion/blob/ -21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py - -[Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors] -https://github.com/CompVis/stable-diffusion/blob/ -21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/LICENSE -""" - -import math - -import numpy as np - -# pytorch_diffusion + derived encoder decoder -import torch -import torch.nn as nn -import torch.nn.functional as F -from loguru import logger as logging - -from cosmos_transfer1.auxiliary.tokenizer.modules.patching import Patcher, UnPatcher -from cosmos_transfer1.auxiliary.tokenizer.modules.utils import Normalize, nonlinearity - - -class Upsample(nn.Module): - def __init__(self, in_channels: int): - super().__init__() - self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3) - return self.conv(x) - - -class Downsample(nn.Module): - def __init__(self, in_channels: int): - super().__init__() - self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - pad = (0, 1, 0, 1) - x = F.pad(x, pad, mode="constant", value=0) - return self.conv(x) - - -class ResnetBlock(nn.Module): - def __init__( - self, - *, - in_channels: int, - out_channels: int = None, - dropout: float, - **kwargs, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - - self.norm1 = Normalize(in_channels) - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.norm2 = Normalize(out_channels) - self.dropout = nn.Dropout(dropout) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.nin_shortcut = ( - nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - if in_channels != out_channels - else nn.Identity() - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - x = self.nin_shortcut(x) - - return x + h - - -class AttnBlock(nn.Module): - def __init__(self, in_channels: int): - super().__init__() - - self.norm = Normalize(in_channels) - self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # TODO (freda): Consider reusing implementations in Attn 
`imaginaire`, - # since than one is gonna be based on TransformerEngine's attn op, - # w/c could ease CP implementations. - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) - k = k.reshape(b, c, h * w) - w_ = torch.bmm(q, k) - w_ = w_ * (int(c) ** (-0.5)) - w_ = F.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) - h_ = torch.bmm(v, w_) - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x + h_ - - -class Encoder(nn.Module): - def __init__( - self, - in_channels: int, - channels: int, - channels_mult: list[int], - num_res_blocks: int, - attn_resolutions: list[int], - dropout: float, - resolution: int, - z_channels: int, - spatial_compression: int, - **ignore_kwargs, - ): - super().__init__() - self.num_resolutions = len(channels_mult) - self.num_res_blocks = num_res_blocks - - # Patcher. - patch_size = ignore_kwargs.get("patch_size", 1) - self.patcher = Patcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) - in_channels = in_channels * patch_size * patch_size - - # calculate the number of downsample operations - self.num_downsamples = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) - assert ( - self.num_downsamples <= self.num_resolutions - ), f"we can only downsample {self.num_resolutions} times at most" - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, channels, kernel_size=3, stride=1, padding=1) - - curr_res = resolution // patch_size - in_ch_mult = (1,) + tuple(channels_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = channels * in_ch_mult[i_level] - block_out = channels * channels_mult[i_level] - for _ in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level < self.num_downsamples: - down.downsample = Downsample(block_in) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.patcher(x) - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1]) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level < self.num_downsamples: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h) - h = self.mid.attn_1(h) - h = self.mid.block_2(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - out_channels: int, - channels: int, - channels_mult: list[int], - num_res_blocks: int, - attn_resolutions: int, - dropout: 
float, - resolution: int, - z_channels: int, - spatial_compression: int, - **ignore_kwargs, - ): - super().__init__() - self.num_resolutions = len(channels_mult) - self.num_res_blocks = num_res_blocks - - # UnPatcher. - patch_size = ignore_kwargs.get("patch_size", 1) - self.unpatcher = UnPatcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) - out_ch = out_channels * patch_size * patch_size - - # calculate the number of upsample operations - self.num_upsamples = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) - assert self.num_upsamples <= self.num_resolutions, f"we can only upsample {self.num_resolutions} times at most" - - block_in = channels * channels_mult[self.num_resolutions - 1] - curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = channels * channels_mult[i_level] - for _ in range(self.num_res_blocks + 1): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level >= (self.num_resolutions - self.num_upsamples): - up.upsample = Upsample(block_in) - curr_res = curr_res * 2 - self.up.insert(0, up) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - def forward(self, z: torch.Tensor) -> torch.Tensor: - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h) - h = self.mid.attn_1(h) - h = self.mid.block_2(h) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level >= (self.num_resolutions - self.num_upsamples): - h = self.up[i_level].upsample(h) - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - h = self.unpatcher(h) - return h diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/layers3d.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/layers3d.py deleted file mode 100644 index b18b907511fc4aa60184a77f950ed169512c1399..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/layers3d.py +++ /dev/null @@ -1,969 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
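# In the 2D Encoder/Decoder above, the number of strided downsample/upsample
# stages is derived from the requested spatial compression and the patch size:
# the Haar/rearrange patcher already reduces resolution by `patch_size`, so only
# the remainder is handled by Downsample/Upsample blocks. Minimal sketch of that
# arithmetic (editorial illustration; helper name and numbers are hypothetical):
import math

def num_resample_stages(spatial_compression: int, patch_size: int = 1) -> int:
    return int(math.log2(spatial_compression)) - int(math.log2(patch_size))

assert num_resample_stages(8, patch_size=1) == 3   # 8x done entirely by conv stages
assert num_resample_stages(8, patch_size=2) == 2   # the patcher contributes one 2x step
assert num_resample_stages(16, patch_size=4) == 2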
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The model definition for 3D layers - -Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/ -9f49074179c912736e617d61b32be367eb5f993a/magvit2_pytorch/magvit2_pytorch.py#L889 - -[MIT License Copyright (c) 2023 Phil Wang] -https://github.com/lucidrains/magvit2-pytorch/blob/ -9f49074179c912736e617d61b32be367eb5f993a/LICENSE -""" -import math -from typing import Tuple, Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from loguru import logger as logging - -from cosmos_transfer1.auxiliary.tokenizer.modules.patching import Patcher, Patcher3D, UnPatcher, UnPatcher3D -from cosmos_transfer1.auxiliary.tokenizer.modules.utils import ( - CausalNormalize, - batch2space, - batch2time, - cast_tuple, - is_odd, - nonlinearity, - replication_pad, - space2batch, - time2batch, -) - -_LEGACY_NUM_GROUPS = 32 - - -class CausalConv3d(nn.Module): - def __init__( - self, - chan_in: int = 1, - chan_out: int = 1, - kernel_size: Union[int, Tuple[int, int, int]] = 3, - pad_mode: str = "constant", - **kwargs, - ): - super().__init__() - kernel_size = cast_tuple(kernel_size, 3) - - time_kernel_size, height_kernel_size, width_kernel_size = kernel_size - - assert is_odd(height_kernel_size) and is_odd(width_kernel_size) - - dilation = kwargs.pop("dilation", 1) - stride = kwargs.pop("stride", 1) - time_stride = kwargs.pop("time_stride", 1) - time_dilation = kwargs.pop("time_dilation", 1) - padding = kwargs.pop("padding", 1) - - self.pad_mode = pad_mode - time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride) - self.time_pad = time_pad - - self.spatial_pad = (padding, padding, padding, padding) - - stride = (time_stride, stride, stride) - dilation = (time_dilation, dilation, dilation) - self.conv3d = nn.Conv3d( - chan_in, - chan_out, - kernel_size, - stride=stride, - dilation=dilation, - **kwargs, - ) - - def _replication_pad(self, x: torch.Tensor) -> torch.Tensor: - x_prev = x[:, :, :1, ...].repeat(1, 1, self.time_pad, 1, 1) - x = torch.cat([x_prev, x], dim=2) - padding = self.spatial_pad + (0, 0) - return F.pad(x, padding, mode=self.pad_mode, value=0.0) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self._replication_pad(x) - return self.conv3d(x) - - -class CausalUpsample3d(nn.Module): - def __init__(self, in_channels: int) -> None: - super().__init__() - self.conv = CausalConv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) - time_factor = 1.0 + 1.0 * (x.shape[2] > 1) - if isinstance(time_factor, torch.Tensor): - time_factor = time_factor.item() - x = x.repeat_interleave(int(time_factor), dim=2) - # TODO(freda): Check if this causes temporal inconsistency. - # Shoule reverse the order of the following two ops, - # better perf and better temporal smoothness. 
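        # Note on the two steps below: `time_factor` is 2 whenever the clip has
        # more than one frame (and 1 for a single-frame input), so the temporal
        # repeat above doubles T before the causal conv, and the trailing slice
        # drops the first (time_factor - 1) frames again. A T-frame input thus
        # comes out with 1 + (T - 1) * 2 frames: the first frame is never
        # temporally upsampled, which preserves causality.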
- x = self.conv(x) - return x[..., int(time_factor - 1) :, :, :] - - -class CausalDownsample3d(nn.Module): - def __init__(self, in_channels: int) -> None: - super().__init__() - self.conv = CausalConv3d( - in_channels, - in_channels, - kernel_size=3, - stride=2, - time_stride=2, - padding=0, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - pad = (0, 1, 0, 1, 0, 0) - x = F.pad(x, pad, mode="constant", value=0) - x = replication_pad(x) - x = self.conv(x) - return x - - -class CausalHybridUpsample3d(nn.Module): - def __init__( - self, - in_channels: int, - spatial_up: bool = True, - temporal_up: bool = True, - **kwargs, - ) -> None: - super().__init__() - self.conv1 = CausalConv3d( - in_channels, - in_channels, - kernel_size=(3, 1, 1), - stride=1, - time_stride=1, - padding=0, - ) - self.conv2 = CausalConv3d( - in_channels, - in_channels, - kernel_size=(1, 3, 3), - stride=1, - time_stride=1, - padding=1, - ) - self.conv3 = CausalConv3d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - time_stride=1, - padding=0, - ) - self.spatial_up = spatial_up - self.temporal_up = temporal_up - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if not self.spatial_up and not self.temporal_up: - return x - - # hybrid upsample temporally. - if self.temporal_up: - time_factor = 1.0 + 1.0 * (x.shape[2] > 1) - if isinstance(time_factor, torch.Tensor): - time_factor = time_factor.item() - x = x.repeat_interleave(int(time_factor), dim=2) - x = x[..., int(time_factor - 1) :, :, :] - x = self.conv1(x) + x - - # hybrid upsample spatially. - if self.spatial_up: - x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) - x = self.conv2(x) + x - - # final 1x1x1 conv. - x = self.conv3(x) - return x - - -class CausalHybridDownsample3d(nn.Module): - def __init__( - self, - in_channels: int, - spatial_down: bool = True, - temporal_down: bool = True, - **kwargs, - ) -> None: - super().__init__() - self.conv1 = CausalConv3d( - in_channels, - in_channels, - kernel_size=(1, 3, 3), - stride=2, - time_stride=1, - padding=0, - ) - self.conv2 = CausalConv3d( - in_channels, - in_channels, - kernel_size=(3, 1, 1), - stride=1, - time_stride=2, - padding=0, - ) - self.conv3 = CausalConv3d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - time_stride=1, - padding=0, - ) - self.spatial_down = spatial_down - self.temporal_down = temporal_down - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if not self.spatial_down and not self.temporal_down: - return x - - # hybrid downsample spatially. - if self.spatial_down: - pad = (0, 1, 0, 1, 0, 0) - x = F.pad(x, pad, mode="constant", value=0) - x1 = self.conv1(x) - x2 = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2)) - x = x1 + x2 - - # hybrid downsample temporally. - if self.temporal_down: - x = replication_pad(x) - x1 = self.conv2(x) - x2 = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1)) - x = x1 + x2 - - # final 1x1x1 conv. 
- x = self.conv3(x) - return x - - -class CausalResnetBlock3d(nn.Module): - def __init__( - self, - *, - in_channels: int, - out_channels: int = None, - dropout: float, - num_groups: int, - ) -> None: - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - - self.norm1 = CausalNormalize(in_channels, num_groups=num_groups) - self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.nin_shortcut = ( - CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - if in_channels != out_channels - else nn.Identity() - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - x = self.nin_shortcut(x) - - return x + h - - -class CausalResnetBlockFactorized3d(nn.Module): - def __init__( - self, - *, - in_channels: int, - out_channels: int = None, - dropout: float, - num_groups: int, - ) -> None: - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - - self.norm1 = CausalNormalize(in_channels, num_groups=1) - self.conv1 = nn.Sequential( - CausalConv3d( - in_channels, - out_channels, - kernel_size=(1, 3, 3), - stride=1, - padding=1, - ), - CausalConv3d( - out_channels, - out_channels, - kernel_size=(3, 1, 1), - stride=1, - padding=0, - ), - ) - self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = nn.Sequential( - CausalConv3d( - out_channels, - out_channels, - kernel_size=(1, 3, 3), - stride=1, - padding=1, - ), - CausalConv3d( - out_channels, - out_channels, - kernel_size=(3, 1, 1), - stride=1, - padding=0, - ), - ) - self.nin_shortcut = ( - CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - if in_channels != out_channels - else nn.Identity() - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - x = self.nin_shortcut(x) - - return x + h - - -class CausalAttnBlock(nn.Module): - def __init__(self, in_channels: int, num_groups: int) -> None: - super().__init__() - - self.norm = CausalNormalize(in_channels, num_groups=num_groups) - self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - q, batch_size = time2batch(q) - k, batch_size = time2batch(k) - v, batch_size = time2batch(v) - - b, c, h, w = q.shape - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) - k = k.reshape(b, c, h * w) - w_ = torch.bmm(q, k) - w_ = w_ * (int(c) ** (-0.5)) - w_ = F.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) - h_ = 
torch.bmm(v, w_) - h_ = h_.reshape(b, c, h, w) - - h_ = batch2time(h_, batch_size) - h_ = self.proj_out(h_) - return x + h_ - - -class CausalTemporalAttnBlock(nn.Module): - def __init__(self, in_channels: int, num_groups: int) -> None: - super().__init__() - - self.norm = CausalNormalize(in_channels, num_groups=num_groups) - self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - q, batch_size, height = space2batch(q) - k, _, _ = space2batch(k) - v, _, _ = space2batch(v) - - bhw, c, t = q.shape - q = q.permute(0, 2, 1) # (bhw, t, c) - k = k.permute(0, 2, 1) # (bhw, t, c) - v = v.permute(0, 2, 1) # (bhw, t, c) - - w_ = torch.bmm(q, k.permute(0, 2, 1)) # (bhw, t, t) - w_ = w_ * (int(c) ** (-0.5)) - - # Apply causal mask - mask = torch.tril(torch.ones_like(w_)) - w_ = w_.masked_fill(mask == 0, float("-inf")) - w_ = F.softmax(w_, dim=2) - - # attend to values - h_ = torch.bmm(w_, v) # (bhw, t, c) - h_ = h_.permute(0, 2, 1).reshape(bhw, c, t) # (bhw, c, t) - - h_ = batch2space(h_, batch_size, height) - h_ = self.proj_out(h_) - return x + h_ - - -class EncoderBase(nn.Module): - def __init__( - self, - in_channels: int, - channels: int, - channels_mult: list[int], - num_res_blocks: int, - attn_resolutions: list[int], - dropout: float, - resolution: int, - z_channels: int, - **ignore_kwargs, - ) -> None: - super().__init__() - self.num_resolutions = len(channels_mult) - self.num_res_blocks = num_res_blocks - - # Patcher. - patch_size = ignore_kwargs.get("patch_size", 1) - self.patcher = Patcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) - in_channels = in_channels * patch_size * patch_size - - # downsampling - self.conv_in = CausalConv3d(in_channels, channels, kernel_size=3, stride=1, padding=1) - - # num of groups for GroupNorm, num_groups=1 for LayerNorm. 
- num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS) - curr_res = resolution // patch_size - in_ch_mult = (1,) + tuple(channels_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = channels * in_ch_mult[i_level] - block_out = channels * channels_mult[i_level] - for _ in range(self.num_res_blocks): - block.append( - CausalResnetBlock3d( - in_channels=block_in, - out_channels=block_out, - dropout=dropout, - num_groups=num_groups, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(CausalAttnBlock(block_in, num_groups=num_groups)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = CausalDownsample3d(block_in) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = CausalResnetBlock3d( - in_channels=block_in, - out_channels=block_in, - dropout=dropout, - num_groups=num_groups, - ) - self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups) - self.mid.block_2 = CausalResnetBlock3d( - in_channels=block_in, - out_channels=block_in, - dropout=dropout, - num_groups=num_groups, - ) - - # end - self.norm_out = CausalNormalize(block_in, num_groups=num_groups) - self.conv_out = CausalConv3d(block_in, z_channels, kernel_size=3, stride=1, padding=1) - - def patcher3d(self, x: torch.Tensor) -> torch.Tensor: - x, batch_size = time2batch(x) - x = self.patcher(x) - x = batch2time(x, batch_size) - return x - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.patcher3d(x) - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1]) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - else: - # temporal downsample (last level) - time_factor = 1 + 1 * (hs[-1].shape[2] > 1) - if isinstance(time_factor, torch.Tensor): - time_factor = time_factor.item() - hs[-1] = replication_pad(hs[-1]) - hs.append( - F.avg_pool3d( - hs[-1], - kernel_size=[time_factor, 1, 1], - stride=[2, 1, 1], - ) - ) - - # middle - h = hs[-1] - h = self.mid.block_1(h) - h = self.mid.attn_1(h) - h = self.mid.block_2(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class DecoderBase(nn.Module): - def __init__( - self, - out_channels: int, - channels: int, - channels_mult: list[int], - num_res_blocks: int, - attn_resolutions: list[int], - dropout: float, - resolution: int, - z_channels: int, - **ignore_kwargs, - ): - super().__init__() - self.num_resolutions = len(channels_mult) - self.num_res_blocks = num_res_blocks - - # UnPatcher. 
- patch_size = ignore_kwargs.get("patch_size", 1) - self.unpatcher = UnPatcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) - out_ch = out_channels * patch_size * patch_size - - block_in = channels * channels_mult[self.num_resolutions - 1] - curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, stride=1, padding=1) - - # num of groups for GroupNorm, num_groups=1 for LayerNorm. - num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS) - - # middle - self.mid = nn.Module() - self.mid.block_1 = CausalResnetBlock3d( - in_channels=block_in, - out_channels=block_in, - dropout=dropout, - num_groups=num_groups, - ) - self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups) - self.mid.block_2 = CausalResnetBlock3d( - in_channels=block_in, - out_channels=block_in, - dropout=dropout, - num_groups=num_groups, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = channels * channels_mult[i_level] - for _ in range(self.num_res_blocks + 1): - block.append( - CausalResnetBlock3d( - in_channels=block_in, - out_channels=block_out, - dropout=dropout, - num_groups=num_groups, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(CausalAttnBlock(block_in, num_groups=num_groups)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = CausalUpsample3d(block_in) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = CausalNormalize(block_in, num_groups=num_groups) - self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - def unpatcher3d(self, x: torch.Tensor) -> torch.Tensor: - x, batch_size = time2batch(x) - x = self.unpatcher(x) - x = batch2time(x, batch_size) - - return x - - def forward(self, z): - h = self.conv_in(z) - - # middle block. - h = self.mid.block_1(h) - h = self.mid.attn_1(h) - h = self.mid.block_2(h) - - # decoder blocks. - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - else: - # temporal upsample (last level) - time_factor = 1.0 + 1.0 * (h.shape[2] > 1) - if isinstance(time_factor, torch.Tensor): - time_factor = time_factor.item() - h = h.repeat_interleave(int(time_factor), dim=2) - h = h[..., int(time_factor - 1) :, :, :] - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - h = self.unpatcher3d(h) - return h - - -class EncoderFactorized(nn.Module): - def __init__( - self, - in_channels: int, - channels: int, - channels_mult: list[int], - num_res_blocks: int, - attn_resolutions: list[int], - dropout: float, - resolution: int, - z_channels: int, - spatial_compression: int = 16, - temporal_compression: int = 8, - **ignore_kwargs, - ) -> None: - super().__init__() - self.num_resolutions = len(channels_mult) - self.num_res_blocks = num_res_blocks - - # Patcher. 
- patch_size = ignore_kwargs.get("patch_size", 1) - self.patcher3d = Patcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange")) - in_channels = in_channels * patch_size * patch_size * patch_size - - # calculate the number of downsample operations - self.num_spatial_downs = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) - assert ( - self.num_spatial_downs <= self.num_resolutions - ), f"Spatially downsample {self.num_resolutions} times at most" - - self.num_temporal_downs = int(math.log2(temporal_compression)) - int(math.log2(patch_size)) - assert ( - self.num_temporal_downs <= self.num_resolutions - ), f"Temporally downsample {self.num_resolutions} times at most" - - # downsampling - self.conv_in = nn.Sequential( - CausalConv3d( - in_channels, - channels, - kernel_size=(1, 3, 3), - stride=1, - padding=1, - ), - CausalConv3d(channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0), - ) - - curr_res = resolution // patch_size - in_ch_mult = (1,) + tuple(channels_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = channels * in_ch_mult[i_level] - block_out = channels * channels_mult[i_level] - for _ in range(self.num_res_blocks): - block.append( - CausalResnetBlockFactorized3d( - in_channels=block_in, - out_channels=block_out, - dropout=dropout, - num_groups=1, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append( - nn.Sequential( - CausalAttnBlock(block_in, num_groups=1), - CausalTemporalAttnBlock(block_in, num_groups=1), - ) - ) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - spatial_down = i_level < self.num_spatial_downs - temporal_down = i_level < self.num_temporal_downs - down.downsample = CausalHybridDownsample3d( - block_in, - spatial_down=spatial_down, - temporal_down=temporal_down, - ) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = CausalResnetBlockFactorized3d( - in_channels=block_in, - out_channels=block_in, - dropout=dropout, - num_groups=1, - ) - self.mid.attn_1 = nn.Sequential( - CausalAttnBlock(block_in, num_groups=1), - CausalTemporalAttnBlock(block_in, num_groups=1), - ) - self.mid.block_2 = CausalResnetBlockFactorized3d( - in_channels=block_in, - out_channels=block_in, - dropout=dropout, - num_groups=1, - ) - - # end - self.norm_out = CausalNormalize(block_in, num_groups=1) - self.conv_out = nn.Sequential( - CausalConv3d(block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1), - CausalConv3d( - z_channels, - z_channels, - kernel_size=(3, 1, 1), - stride=1, - padding=0, - ), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.patcher3d(x) - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1]) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h) - h = self.mid.attn_1(h) - h = self.mid.block_2(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class DecoderFactorized(nn.Module): - def __init__( - self, - out_channels: int, - channels: int, - channels_mult: list[int], - num_res_blocks: int, - 
attn_resolutions: list[int], - dropout: float, - resolution: int, - z_channels: int, - spatial_compression: int = 16, - temporal_compression: int = 8, - **ignore_kwargs, - ): - super().__init__() - self.num_resolutions = len(channels_mult) - self.num_res_blocks = num_res_blocks - - # UnPatcher. - patch_size = ignore_kwargs.get("patch_size", 1) - self.unpatcher3d = UnPatcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange")) - out_ch = out_channels * patch_size * patch_size * patch_size - - # calculate the number of upsample operations - self.num_spatial_ups = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) - assert self.num_spatial_ups <= self.num_resolutions, f"Spatially upsample {self.num_resolutions} times at most" - self.num_temporal_ups = int(math.log2(temporal_compression)) - int(math.log2(patch_size)) - assert ( - self.num_temporal_ups <= self.num_resolutions - ), f"Temporally upsample {self.num_resolutions} times at most" - - block_in = channels * channels_mult[self.num_resolutions - 1] - curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = nn.Sequential( - CausalConv3d(z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1), - CausalConv3d(block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0), - ) - - # middle - self.mid = nn.Module() - self.mid.block_1 = CausalResnetBlockFactorized3d( - in_channels=block_in, - out_channels=block_in, - dropout=dropout, - num_groups=1, - ) - self.mid.attn_1 = nn.Sequential( - CausalAttnBlock(block_in, num_groups=1), - CausalTemporalAttnBlock(block_in, num_groups=1), - ) - self.mid.block_2 = CausalResnetBlockFactorized3d( - in_channels=block_in, - out_channels=block_in, - dropout=dropout, - num_groups=1, - ) - - legacy_mode = ignore_kwargs.get("legacy_mode", False) - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = channels * channels_mult[i_level] - for _ in range(self.num_res_blocks + 1): - block.append( - CausalResnetBlockFactorized3d( - in_channels=block_in, - out_channels=block_out, - dropout=dropout, - num_groups=1, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append( - nn.Sequential( - CausalAttnBlock(block_in, num_groups=1), - CausalTemporalAttnBlock(block_in, num_groups=1), - ) - ) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - # The layer index for temporal/spatial downsampling performed - # in the encoder should correspond to the layer index in - # reverse order where upsampling is performed in the decoder. - # If you've a pre-trained model, you can simply finetune. 
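                # Worked example with hypothetical sizes (patch_size=1,
                # num_resolutions=4, 8x spatial -> num_spatial_ups=3,
                # 4x temporal -> num_temporal_ups=2): i_level 3, 2, 1 map to
                # i_level_reverse 0, 1, 2. In the non-legacy branch below,
                # temporal_up is True for i_level_reverse 1 and 2 and False for
                # 0, while spatial_up is True for all three, giving exactly
                # three 2x spatial and two 2x temporal upsamples. legacy_mode
                # instead places the temporal upsamples at i_level_reverse 0
                # and 1, i.e. the deepest decoder stages.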
- i_level_reverse = self.num_resolutions - i_level - 1 - if legacy_mode: - temporal_up = i_level_reverse < self.num_temporal_ups - else: - temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1 - spatial_up = temporal_up or ( - i_level_reverse < self.num_spatial_ups and self.num_spatial_ups > self.num_temporal_ups - ) - up.upsample = CausalHybridUpsample3d(block_in, spatial_up=spatial_up, temporal_up=temporal_up) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = CausalNormalize(block_in, num_groups=1) - self.conv_out = nn.Sequential( - CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1), - CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0), - ) - - def forward(self, z): - h = self.conv_in(z) - - # middle block. - h = self.mid.block_1(h) - h = self.mid.attn_1(h) - h = self.mid.block_2(h) - - # decoder blocks. - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - h = self.unpatcher3d(h) - return h diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/patching.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/patching.py deleted file mode 100644 index 50b7e88150e8b3e86bc9052c832feb89aed71924..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/patching.py +++ /dev/null @@ -1,311 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The patcher and unpatcher implementation for 2D and 3D data. - -The idea of Haar wavelet is to compute LL, LH, HL, HH component as two 1D convolutions. -One on the rows and one on the columns. -For example, in 1D signal, we have [a, b], then the low-freq compoenent is [a + b] / 2 and high-freq is [a - b] / 2. -We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component. -For H component, we can use a 1D convolution with kernel [1, -1] and stride 2. -Although in principle, we typically only do additional Haar wavelet over the LL component. But here we do it for all - as we need to support downsampling for more than 2x. -For example, 4x downsampling can be done by 2x Haar and additional 2x Haar, and the shape would be. 
- [3, 256, 256] -> [12, 128, 128] -> [48, 64, 64] -""" - -import torch -import torch.nn.functional as F -from einops import rearrange - -_WAVELETS = { - "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]), - "rearrange": torch.tensor([1.0, 1.0]), -} -_PERSISTENT = False - - -class Patcher(torch.nn.Module): - """A module to convert image tensors into patches using torch operations. - - The main difference from `class Patching` is that this module implements - all operations using torch, rather than python or numpy, for efficiency purpose. - - It's bit-wise identical to the Patching module outputs, with the added - benefit of being torch.jit scriptable. - """ - - def __init__(self, patch_size=1, patch_method="haar"): - super().__init__() - self.patch_size = patch_size - self.patch_method = patch_method - self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT) - self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) - self.register_buffer( - "_arange", - torch.arange(_WAVELETS[patch_method].shape[0]), - persistent=_PERSISTENT, - ) - for param in self.parameters(): - param.requires_grad = False - - def forward(self, x): - if self.patch_method == "haar": - return self._haar(x) - elif self.patch_method == "rearrange": - return self._arrange(x) - else: - raise ValueError("Unknown patch method: " + self.patch_method) - - def _dwt(self, x, mode="reflect", rescale=False): - dtype = x.dtype - h = self.wavelets - - n = h.shape[0] - g = x.shape[1] - hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) - hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) - hh = hh.to(dtype=dtype) - hl = hl.to(dtype=dtype) - - x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) - xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2)) - xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2)) - xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1)) - xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1)) - xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1)) - xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1)) - - out = torch.cat([xll, xlh, xhl, xhh], dim=1) - if rescale: - out = out / 2 - return out - - def _haar(self, x): - for _ in self.range: - x = self._dwt(x, rescale=True) - return x - - def _arrange(self, x): - x = rearrange( - x, - "b c (h p1) (w p2) -> b (c p1 p2) h w", - p1=self.patch_size, - p2=self.patch_size, - ).contiguous() - return x - - -class Patcher3D(Patcher): - """A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos.""" - - def __init__(self, patch_size=1, patch_method="haar"): - super().__init__(patch_method=patch_method, patch_size=patch_size) - self.register_buffer( - "patch_size_buffer", - patch_size * torch.ones([1], dtype=torch.int32), - persistent=_PERSISTENT, - ) - - def _dwt(self, x, wavelet, mode="reflect", rescale=False): - dtype = x.dtype - h = self.wavelets - - n = h.shape[0] - g = x.shape[1] - hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) - hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) - hh = hh.to(dtype=dtype) - hl = hl.to(dtype=dtype) - - # Handles temporal axis. - x = F.pad(x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) - xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) - xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) - - # Handles spatial axes. 
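        # The temporally filtered xl / xh are now filtered along height and then
        # along width with the same low/high kernel pair, producing the eight
        # subbands xlll..xhhh concatenated on the channel axis. With the
        # orthonormal Haar kernel each 1D pass scales a constant signal by
        # sqrt(2), so the 1/(2*sqrt(2)) rescale at the end cancels the
        # (sqrt(2))**3 gain of the three passes.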
- xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) - xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) - xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) - xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) - - xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - - out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1) - if rescale: - out = out / (2 * torch.sqrt(torch.tensor(2.0))) - return out - - def _haar(self, x): - xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) - x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) - for _ in self.range: - x = self._dwt(x, "haar", rescale=True) - return x - - def _arrange(self, x): - xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) - x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) - x = rearrange( - x, - "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w", - p1=self.patch_size, - p2=self.patch_size, - p3=self.patch_size, - ).contiguous() - return x - - -class UnPatcher(torch.nn.Module): - """A module to convert patches into image tensorsusing torch operations. - - The main difference from `class Unpatching` is that this module implements - all operations using torch, rather than python or numpy, for efficiency purpose. - - It's bit-wise identical to the Unpatching module outputs, with the added - benefit of being torch.jit scriptable. - """ - - def __init__(self, patch_size=1, patch_method="haar"): - super().__init__() - self.patch_size = patch_size - self.patch_method = patch_method - self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT) - self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) - self.register_buffer( - "_arange", - torch.arange(_WAVELETS[patch_method].shape[0]), - persistent=_PERSISTENT, - ) - for param in self.parameters(): - param.requires_grad = False - - def forward(self, x): - if self.patch_method == "haar": - return self._ihaar(x) - elif self.patch_method == "rearrange": - return self._iarrange(x) - else: - raise ValueError("Unknown patch method: " + self.patch_method) - - def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False): - dtype = x.dtype - h = self.wavelets - n = h.shape[0] - - g = x.shape[1] // 4 - hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) - hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) - hh = hh.to(dtype=dtype) - hl = hl.to(dtype=dtype) - - xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1) - - # Inverse transform. 
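        # Each pair of transposed convolutions below merges one low/high subband
        # pair back along a single axis. In the 1D example from the module
        # docstring, l = (a + b) / 2 and h = (a - b) / 2, so the synthesis is
        # simply a = l + h and b = l - h; the same recombination is done here
        # separably (rows, then columns) with the Haar kernels, and the final
        # rescale by 2 undoes the forward transform's division by 2.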
- yl = torch.nn.functional.conv_transpose2d(xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) - yl += torch.nn.functional.conv_transpose2d(xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) - yh = torch.nn.functional.conv_transpose2d(xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) - yh += torch.nn.functional.conv_transpose2d(xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) - y = torch.nn.functional.conv_transpose2d(yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)) - y += torch.nn.functional.conv_transpose2d(yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)) - - if rescale: - y = y * 2 - return y - - def _ihaar(self, x): - for _ in self.range: - x = self._idwt(x, "haar", rescale=True) - return x - - def _iarrange(self, x): - x = rearrange( - x, - "b (c p1 p2) h w -> b c (h p1) (w p2)", - p1=self.patch_size, - p2=self.patch_size, - ) - return x - - -class UnPatcher3D(UnPatcher): - """A 3D inverse discrete wavelet transform for video wavelet decompositions.""" - - def __init__(self, patch_size=1, patch_method="haar"): - super().__init__(patch_method=patch_method, patch_size=patch_size) - - def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False): - dtype = x.dtype - h = self.wavelets - n = h.shape[0] - - g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors. - hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) - hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) - hl = hl.to(dtype=dtype) - hh = hh.to(dtype=dtype) - - xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1) - - # Height height transposed convolutions. - xll = F.conv_transpose3d(xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - xll += F.conv_transpose3d(xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - - xlh = F.conv_transpose3d(xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - xlh += F.conv_transpose3d(xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - - xhl = F.conv_transpose3d(xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - xhl += F.conv_transpose3d(xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - - xhh = F.conv_transpose3d(xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - xhh += F.conv_transpose3d(xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) - - # Handles width transposed convolutions. - xl = F.conv_transpose3d(xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) - xl += F.conv_transpose3d(xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) - xh = F.conv_transpose3d(xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) - xh += F.conv_transpose3d(xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) - - # Handles time axis transposed convolutions. - x = F.conv_transpose3d(xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) - x += F.conv_transpose3d(xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) - - if rescale: - x = x * (2 * torch.sqrt(torch.tensor(2.0))) - return x - - def _ihaar(self, x): - for _ in self.range: - x = self._idwt(x, "haar", rescale=True) - x = x[:, :, self.patch_size - 1 :, ...] - return x - - def _iarrange(self, x): - x = rearrange( - x, - "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", - p1=self.patch_size, - p2=self.patch_size, - p3=self.patch_size, - ) - x = x[:, :, self.patch_size - 1 :, ...] 
- return x diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/quantizers.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/quantizers.py deleted file mode 100644 index 74dfec5112f91c2bbb760e0af790434849cad6d0..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/quantizers.py +++ /dev/null @@ -1,513 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Quantizers for discrete image and video tokenization.""" - -from typing import Optional - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import reduce -from loguru import logger as logging - -from cosmos_transfer1.auxiliary.tokenizer.modules.utils import ( - default, - entropy, - pack_one, - rearrange, - round_ste, - unpack_one, -) - - -class ResidualFSQuantizer(nn.Module): - """Residual Finite Scalar Quantization - - Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf - """ - - def __init__(self, levels: list[int], num_quantizers: int, **ignore_kwargs): - super().__init__() - self.dtype = ignore_kwargs.get("dtype", torch.float32) - self.layers = nn.ModuleList([FSQuantizer(levels=levels) for _ in range(num_quantizers)]) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - indices_stack = [] - residual = x - quantized_out = 0 - loss_out = 0 - for i, layer in enumerate(self.layers): - quant_indices, z, loss = layer(residual) - indices_stack.append(quant_indices) - residual = residual - z.detach() - quantized_out = quantized_out + z - loss_out = loss_out + loss - self.residual = residual - indices = torch.stack(indices_stack, dim=1) - return indices, quantized_out.to(self.dtype), loss_out.to(self.dtype) - - def indices_to_codes(self, indices_stack: torch.Tensor) -> torch.Tensor: - quantized_out = 0 - for layer, indices in zip(self.layers, indices_stack.transpose(0, 1)): - quantized_out += layer.indices_to_codes(indices) - return quantized_out - - -class FSQuantizer(nn.Module): - """Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 - - Code adapted from Jax version in Appendix A.1. 
- - Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/ - vector_quantize_pytorch/finite_scalar_quantization.py - [Copyright (c) 2020 Phil Wang] - https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE - """ - - def __init__( - self, - levels: list[int], - dim: Optional[int] = None, - num_codebooks=1, - keep_num_codebooks_dim: Optional[bool] = None, - scale: Optional[float] = None, - **ignore_kwargs, - ): - super().__init__() - self.dtype = ignore_kwargs.get("dtype", torch.bfloat16) - _levels = torch.tensor(levels, dtype=torch.int32) - self.register_buffer("_levels", _levels, persistent=False) - - _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32) - self.register_buffer("_basis", _basis, persistent=False) - - self.scale = scale - - codebook_dim = len(levels) - self.codebook_dim = codebook_dim - - effective_codebook_dim = codebook_dim * num_codebooks - self.num_codebooks = num_codebooks - self.effective_codebook_dim = effective_codebook_dim - - keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) - assert not (num_codebooks > 1 and not keep_num_codebooks_dim) - self.keep_num_codebooks_dim = keep_num_codebooks_dim - - self.dim = default(dim, len(_levels) * num_codebooks) - - has_projections = self.dim != effective_codebook_dim - self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity() - self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity() - self.has_projections = has_projections - - self.codebook_size = self._levels.prod().item() - - implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False) - self.register_buffer("implicit_codebook", implicit_codebook, persistent=False) - - def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: - """Bound `z`, an array of shape (..., d).""" - half_l = (self._levels - 1) * (1 + eps) / 2 - offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) - shift = (offset / half_l).atanh() - return (z + shift).tanh() * half_l - offset - - def quantize(self, z: torch.Tensor) -> torch.Tensor: - """Quantizes z, returns quantized zhat, same shape as z.""" - quantized = round_ste(self.bound(z)) - half_width = self._levels // 2 # Renormalize to [-1, 1]. - return quantized / half_width - - def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor: - half_width = self._levels // 2 - return (zhat_normalized * half_width) + half_width - - def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor: - half_width = self._levels // 2 - return (zhat - half_width) / half_width - - def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor: - """Converts a `code` to an index in the codebook.""" - assert zhat.shape[-1] == self.codebook_dim - zhat = self._scale_and_shift(zhat).float() - return (zhat * self._basis).sum(dim=-1).to(torch.int32) - - def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor: - """Inverse of `codes_to_indices`.""" - is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) - indices = rearrange(indices, "... -> ... 1") - codes_non_centered = (indices // self._basis) % self._levels - codes = self._scale_and_shift_inverse(codes_non_centered) - - if self.keep_num_codebooks_dim: - codes = rearrange(codes, "... c d -> ... 
(c d)") - - if project_out: - codes = self.project_out(codes) - - if is_img_or_video: - codes = rearrange(codes, "b ... d -> b d ...") - - return codes.to(self.dtype) - - def forward(self, z: torch.Tensor) -> torch.Tensor: - """ - einstein notation - b - batch - n - sequence (or flattened spatial dimensions) - d - feature dimension, which is also log2(codebook size) - c - number of codebook dim - """ - is_img_or_video = z.ndim >= 4 - - # standardize image or video into (batch, seq, dimension) - - if is_img_or_video: - z = rearrange(z, "b d ... -> b ... d") - z, ps = pack_one(z, "b * d") - - assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" - - z = self.project_in(z) - - z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) - - codes = self.quantize(z) - indices = self.codes_to_indices(codes) - - codes = rearrange(codes, "b n c d -> b n (c d)") - - out = self.project_out(codes) - - # reconstitute image or video dimensions - - if is_img_or_video: - out = unpack_one(out, ps, "b * d") - out = rearrange(out, "b ... d -> b d ...") - indices = unpack_one(indices, ps, "b * c") - dummy_loss = torch.zeros_like(out.mean(dim=[1, 2, 3], keepdim=True)) - else: - dummy_loss = torch.zeros_like(out.mean(dim=[1, 2], keepdim=True)).unsqueeze(1) - - if not self.keep_num_codebooks_dim: - indices = rearrange(indices, "... 1 -> ...") - - return (indices, out.to(self.dtype), dummy_loss) - - -class VectorQuantizer(nn.Module): - """Improved version over VectorQuantizer. Mostly - avoids costly matrix multiplications and allows for post-hoc remapping of indices. - - Adapted from: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/ - taming/modules/vqvae/quantize.py - - [Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer] - https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/License.txt - """ - - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - beta: float = 0.25, - remap: str = None, - unknown_index: str = "random", - sane_index_shape: bool = False, - legacy: bool = True, - use_norm=False, - **ignore_kwargs, - ): - super().__init__() - self.n_e = num_embeddings - self.e_dim = embedding_dim - self.beta = beta - self.legacy = legacy - self.norm = lambda x: F.normalize(x, dim=-1) if use_norm else x - - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - print( - f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." 
- ) - else: - self.re_embed = num_embeddings - - self.sane_index_shape = sane_index_shape - self.dtype = ignore_kwargs.get("dtype", torch.float32) - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - match = (inds[:, :, None] == used[None, None, ...]).long() - new = match.argmax(-1) - unknown = match.sum(2) < 1 - if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds >= self.used.shape[0]] = 0 # simply set to zero - back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) - return back.reshape(ishape) - - def forward(self, z, temp=None, rescale_logits=False, return_logits=False): - assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" - assert rescale_logits is False, "Only for interface compatible with Gumbel" - assert return_logits is False, "Only for interface compatible with Gumbel" - z = rearrange(z, "b c h w -> b h w c").contiguous() - z_flattened = z.view(-1, self.e_dim) - - d = ( - torch.sum(z_flattened**2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight**2, dim=1) - - 2 - * torch.einsum( - "bd,dn->bn", - z_flattened, - rearrange(self.embedding.weight, "n d -> d n"), - ) - ) - - encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) - encodings = torch.zeros(encoding_indices.shape[0], self.n_e, device=z.device) - encodings.scatter_(1, encoding_indices, 1) - z_q = torch.matmul(encodings, self.embedding.weight).view(z.shape) - min_encodings = None - - z_q, z = self.norm(z_q), self.norm(z) - - # compute loss for embedding - commit_loss = torch.mean((z_q - z.detach()) ** 2, dim=[1, 2, 3], keepdim=True) - emb_loss = torch.mean((z_q.detach() - z) ** 2, dim=[1, 2, 3], keepdim=True) - if not self.legacy: - loss = self.beta * emb_loss + commit_loss - else: - loss = emb_loss + self.beta * commit_loss - - # preserve gradients - z_q = z + (z_q - z).detach() - avg_probs = torch.mean(encodings, dim=0) - perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) - - # reshape back to match original input shape - z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() - - if self.remap is not None: - min_encoding_indices = encoding_indices.squeeze(1).reshape(z.shape[0], -1) # add batch axis - min_encoding_indices = self.remap_to_used(encoding_indices.squeeze(1)) - min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten - - if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) - - # TODO: return (indices, z_q, loss) - return ( - z_q, - loss, - ( - encoding_indices.squeeze(1), - min_encodings, - commit_loss.mean().detach(), - self.beta * emb_loss.mean().detach(), - perplexity.mean().detach(), - ), - ) - - def get_codebook_entry(self, indices, shape): - # shape specifying (batch, height, width, channel) - if self.remap is not None: - indices = indices.reshape(shape[0], -1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to 
match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q - - -class LFQuantizer(nn.Module): - """Lookup-Free Quantization - - Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/ - vector_quantize_pytorch/lookup_free_quantization.py - [Copyright (c) 2020 Phil Wang] - https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE - """ - - def __init__( - self, - *, - codebook_size: int, - codebook_dim: int, - embed_dim: Optional[int] = None, # if None, use codebook_dim - entropy_loss_weight=0.1, - commitment_loss_weight=0.25, - default_temp: float = 0.01, - entropy_loss: bool = False, - **ignore_kwargs, - ): - """Lookup-Free Quantization - - Args: - codebook_size (int): The number of entries in the codebook. - codebook_dim (int): The number of bits in each code. - embed_dim (Optional[int], optional): The dimension of the input embedding. Defaults to None. - entropy_loss_weight (float, optional): Whether to use entropy loss. Defaults to 0.1. - commitment_loss_weight (float, optional): Weight for commitment loss. Defaults to 0.25. - default_temp (float, optional): The temprature to use. Defaults to 0.01. - entropy_loss (bool, optional): Flag for entropy loss. Defaults to False. - """ - super().__init__() - self.entropy_loss = entropy_loss - self.codebook_dim = codebook_dim - self.default_temp = default_temp - self.entrop_loss_weight = entropy_loss_weight - self.commitment_loss_weight = commitment_loss_weight - embed_dim = embed_dim or codebook_dim - - has_projections = embed_dim != codebook_dim - self.project_in = nn.Linear(embed_dim, codebook_dim) if has_projections else nn.Identity() - self.project_out = nn.Linear(codebook_dim, embed_dim) if has_projections else nn.Identity() - logging.info(f"LFQ: has_projections={has_projections}, dim_in={embed_dim}, codebook_dim={codebook_dim}") - - self.dtype = ignore_kwargs.get("dtype", torch.float32) - - if entropy_loss: - assert 2**codebook_dim == codebook_size, "codebook size must be 2 ** codebook_dim" - self.codebook_size = codebook_size - - self.register_buffer( - "mask", - 2 ** torch.arange(codebook_dim - 1, -1, -1), - persistent=False, - ) - self.register_buffer("zero", torch.tensor(0.0), persistent=False) - - all_codes = torch.arange(codebook_size) - bits = ((all_codes[..., None].int() & self.mask) != 0).float() - codebook = 2 * bits - 1.0 - - self.register_buffer("codebook", codebook, persistent=False) # [codebook_size, codebook_dim] - - def forward(self, z: torch.Tensor, temp: float = None) -> torch.Tensor: - temp = temp or self.default_temp - - z = rearrange(z, "b d ... -> b ... d") - z, ps = pack_one(z, "b * d") - z = self.project_in(z) - - # split out number of codebooks - z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) - - # quantization - original_input = z - - codebook_value = torch.ones_like(z) - z_q = torch.where(z > 0, codebook_value, -codebook_value) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # commit loss - commit_loss = ((original_input - z_q.detach()) ** 2).mean(dim=[1, 2, 3]) - - z_q = rearrange(z_q, "b n c d -> b n (c d)") - z_q = self.project_out(z_q) - - # reshape - z_q = unpack_one(z_q, ps, "b * d") - z_q = rearrange(z_q, "b ... 
d -> b d ...") - - loss = self.commitment_loss_weight * commit_loss - - # entropy loss (eq-5) - if self.entropy_loss: - # indices - indices = reduce((z > 0).int() * self.mask.int(), "b n c d -> b n c", "sum") - indices = unpack_one(indices, ps, "b * c") - indices = rearrange(indices, "... 1 -> ...") - - distance = -2 * torch.einsum( - "... i d, j d -> ... i j", - original_input, - self.codebook.to(original_input.dtype), - ) - prob = (-distance / temp).softmax(dim=-1) - per_sample_entropy = entropy(prob).mean(dim=[1, 2]) - avg_prob = reduce(prob, "... c d -> c d", "mean") - codebook_entropy = entropy(avg_prob).mean() - entropy_aux_loss = per_sample_entropy - codebook_entropy - - loss += self.entrop_loss_weight * entropy_aux_loss - - # TODO: return (indices, z_q, loss) - return ( - z_q, - loss.unsqueeze(1).unsqueeze(1).unsqueeze(1), - ( - indices, - self.commitment_loss_weight * commit_loss.mean().detach(), - self.entrop_loss_weight * entropy_aux_loss.mean().detach(), - self.entrop_loss_weight * per_sample_entropy.mean().detach(), - self.entrop_loss_weight * codebook_entropy.mean().detach(), - ), - ) - else: - return ( - z_q, - loss.unsqueeze(1).unsqueeze(1).unsqueeze(1), - self.commitment_loss_weight * commit_loss.mean().detach(), - ) - - -class InvQuantizerJit(nn.Module): - """Use for decoder_jit to trace quantizer in discrete tokenizer""" - - def __init__(self, quantizer): - super().__init__() - self.quantizer = quantizer - - def forward(self, indices: torch.Tensor): - codes = self.quantizer.indices_to_codes(indices) - return codes.to(self.quantizer.dtype) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/utils.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/utils.py deleted file mode 100644 index 578bf2fa3f15e0dbe05054d30fda380f6d93e53f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/modules/utils.py +++ /dev/null @@ -1,116 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
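[Editor's illustrative sketch] VectorQuantizer above finds nearest codebook entries through the expansion ||z - e||^2 = ||z||^2 + ||e||^2 - 2*z.e rather than materialising all pairwise differences. A small self-contained check of that identity with toy shapes (not the module's tensors):

import torch

z = torch.randn(10, 4)    # flattened latents (n, e_dim)
e = torch.randn(32, 4)    # codebook (n_e, e_dim)

# Squared distances via the expanded form used in VectorQuantizer.forward.
d = z.pow(2).sum(1, keepdim=True) + e.pow(2).sum(1) - 2.0 * z @ e.t()

assert torch.allclose(d, torch.cdist(z, e) ** 2, atol=1e-4)
indices = d.argmin(dim=1)        # nearest codebook entry per latent
print(indices.shape)             # torch.Size([10])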
- -"""Shared utilities for the networks module.""" - -from typing import Any - -import torch -from einops import pack, rearrange, unpack - - -def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: - batch_size = x.shape[0] - return rearrange(x, "b c t h w -> (b t) c h w"), batch_size - - -def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor: - return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) - - -def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: - batch_size, height = x.shape[0], x.shape[-2] - return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height - - -def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor: - return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height) - - -def cast_tuple(t: Any, length: int = 1) -> Any: - return t if isinstance(t, tuple) else ((t,) * length) - - -def replication_pad(x): - return torch.cat([x[:, :, :1, ...], x], dim=2) - - -def divisible_by(num: int, den: int) -> bool: - return (num % den) == 0 - - -def is_odd(n: int) -> bool: - return not divisible_by(n, 2) - - -def nonlinearity(x): - return x * torch.sigmoid(x) - - -def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) - - -class CausalNormalize(torch.nn.Module): - def __init__(self, in_channels, num_groups=1): - super().__init__() - self.norm = torch.nn.GroupNorm( - num_groups=num_groups, - num_channels=in_channels, - eps=1e-6, - affine=True, - ) - self.num_groups = num_groups - - def forward(self, x): - # if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose. - # All new models should use num_groups=1, otherwise causality is not guaranteed. - if self.num_groups == 1: - x, batch_size = time2batch(x) - return batch2time(self.norm(x), batch_size) - return self.norm(x) - - -def exists(v): - return v is not None - - -def default(*args): - for arg in args: - if exists(arg): - return arg - return None - - -def pack_one(t, pattern): - return pack([t], pattern) - - -def unpack_one(t, ps, pattern): - return unpack(t, ps, pattern)[0] - - -def round_ste(z: torch.Tensor) -> torch.Tensor: - """Round with straight through gradients.""" - zhat = z.round() - return z + (zhat - z).detach() - - -def log(t, eps=1e-5): - return t.clamp(min=eps).log() - - -def entropy(prob): - return (-prob * log(prob)).sum(dim=-1) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/__init__.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/__init__.py deleted file mode 100644 index 75f19a11fa29710eea3fa40b275ae0bff96a0bc3..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
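[Editor's illustrative sketch] round_ste above is the straight-through rounding that FSQuantizer builds on. A toy finite-scalar-quantization step with assumed levels=[3, 5] (illustrative values, not the shipped config), mirroring the bound/quantize logic:

import torch

levels = torch.tensor([3, 5])   # assumed per-channel levels, for illustration only

def round_ste(z: torch.Tensor) -> torch.Tensor:
    # Round with straight-through gradients, as in utils.round_ste.
    return z + (z.round() - z).detach()

def fsq(z: torch.Tensor) -> torch.Tensor:
    half_l = (levels - 1) * (1 + 1e-3) / 2
    offset = torch.where(levels % 2 == 0, torch.tensor(0.5), torch.tensor(0.0))
    shift = (offset / half_l).atanh()
    bounded = (z + shift).tanh() * half_l - offset   # squash into the level range
    return round_ste(bounded) / (levels // 2)        # renormalize to [-1, 1]

z = torch.randn(8, 2, requires_grad=True)
zhat = fsq(z)
zhat.sum().backward()   # gradients reach z despite the rounding
print(zhat.detach()[:, 0].unique().numel() <= 3, zhat.detach()[:, 1].unique().numel() <= 5)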
- -from enum import Enum - -from cosmos_transfer1.auxiliary.tokenizer.networks.configs import continuous_image as continuous_image_dict -from cosmos_transfer1.auxiliary.tokenizer.networks.configs import continuous_video as continuous_video_dict -from cosmos_transfer1.auxiliary.tokenizer.networks.configs import discrete_image as discrete_image_dict -from cosmos_transfer1.auxiliary.tokenizer.networks.configs import discrete_video as discrete_video_dict -from cosmos_transfer1.auxiliary.tokenizer.networks.continuous_image import ContinuousImageTokenizer -from cosmos_transfer1.auxiliary.tokenizer.networks.continuous_video import CausalContinuousVideoTokenizer -from cosmos_transfer1.auxiliary.tokenizer.networks.discrete_image import DiscreteImageTokenizer -from cosmos_transfer1.auxiliary.tokenizer.networks.discrete_video import CausalDiscreteVideoTokenizer - - -class TokenizerConfigs(Enum): - CI = continuous_image_dict - DI = discrete_image_dict - CV = continuous_video_dict - DV = discrete_video_dict - - -class TokenizerModels(Enum): - CI = ContinuousImageTokenizer - DI = DiscreteImageTokenizer - CV = CausalContinuousVideoTokenizer - DV = CausalDiscreteVideoTokenizer diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/configs.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/configs.py deleted file mode 100644 index d9864bf21ed5647349de535ba3a24d8ffbf38df5..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/configs.py +++ /dev/null @@ -1,147 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The default image and video tokenizer configs.""" - -from cosmos_transfer1.auxiliary.tokenizer.modules import ( - ContinuousFormulation, - Decoder3DType, - DecoderType, - DiscreteQuantizer, - Encoder3DType, - EncoderType, -) - -continuous_image = dict( - # The attention resolution for res blocks. - attn_resolutions=[32], - # The base number of channels. - channels=128, - # The channel multipler for each resolution. - channels_mult=[2, 4, 4], - dropout=0.0, - in_channels=3, - # The spatial compression ratio. - spatial_compression=16, - # The number of layers in each res block. - num_res_blocks=2, - out_channels=3, - resolution=1024, - patch_size=4, - patch_method="haar", - # The output latent dimension (channels). - latent_channels=16, - # The encoder output channels just before sampling. - # Which is also the decoder's input channels. - z_channels=16, - # A factor over the z_channels, to get the total channels the encoder should output. - # For a VAE for instance, we want to output the mean and variance, so we need 2 * z_channels. - z_factor=1, - name="CI", - # What formulation to use, either "AE" or "VAE". - # Chose VAE here, since the pre-trained ckpt were of a VAE formulation. 
- formulation=ContinuousFormulation.AE.name, - # Specify type of encoder ["Default", "LiteVAE"] - encoder=EncoderType.Default.name, - # Specify type of decoder ["Default"] - decoder=DecoderType.Default.name, -) - -discrete_image = dict( - # The attention resolution for res blocks. - attn_resolutions=[32], - # The base number of channels. - channels=128, - # The channel multipler for each resolution. - channels_mult=[2, 4, 4], - dropout=0.0, - in_channels=3, - # The spatial compression ratio. - spatial_compression=16, - # The number of layers in each res block. - num_res_blocks=2, - out_channels=3, - resolution=1024, - patch_size=4, - patch_method="haar", - # The encoder output channels just before sampling. - z_channels=256, - # A factor over the z_channels, to get the total channels the encoder should output. - # for discrete tokenization, often we directly use the vector, so z_factor=1. - z_factor=1, - # The quantizer of choice, VQ, LFQ, FSQ, or ResFSQ. - quantizer=DiscreteQuantizer.FSQ.name, - # The embedding dimension post-quantization, which is also the input channels of the decoder. - # Which is also the output - embedding_dim=6, - # The number of levels to use for fine-scalar quantization. - levels=[8, 8, 8, 5, 5, 5], - # The number of quantizers to use for residual fine-scalar quantization. - num_quantizers=4, - name="DI", - # Specify type of encoder ["Default", "LiteVAE"] - encoder=EncoderType.Default.name, - # Specify type of decoder ["Default"] - decoder=DecoderType.Default.name, -) - -continuous_video = dict( - attn_resolutions=[32], - channels=128, - channels_mult=[2, 4, 4], - dropout=0.0, - in_channels=3, - num_res_blocks=2, - out_channels=3, - resolution=1024, - patch_size=4, - patch_method="haar", - latent_channels=16, - z_channels=16, - z_factor=1, - num_groups=1, - legacy_mode=False, - spatial_compression=8, - temporal_compression=8, - formulation=ContinuousFormulation.AE.name, - encoder=Encoder3DType.FACTORIZED.name, - decoder=Decoder3DType.FACTORIZED.name, - name="CV", -) - -discrete_video = dict( - attn_resolutions=[32], - channels=128, - channels_mult=[2, 4, 4], - dropout=0.0, - in_channels=3, - num_res_blocks=2, - out_channels=3, - resolution=1024, - patch_size=4, - patch_method="haar", - z_channels=16, - z_factor=1, - num_groups=1, - legacy_mode=False, - spatial_compression=16, - temporal_compression=8, - quantizer=DiscreteQuantizer.FSQ.name, - embedding_dim=6, - levels=[8, 8, 8, 5, 5, 5], - encoder=Encoder3DType.FACTORIZED.name, - decoder=Decoder3DType.FACTORIZED.name, - name="DV", -) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/continuous_image.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/continuous_image.py deleted file mode 100644 index 5915cde1a757dc6343db342ee9c240b6604a4454..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/continuous_image.py +++ /dev/null @@ -1,91 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
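[Editor's illustrative sketch] The enums in networks/__init__.py above pair each config dict with its tokenizer class. A hypothetical instantiation, assuming the cosmos_transfer1 tokenizer package is importable; the keys used are those of the continuous_video dict shown above:

from cosmos_transfer1.auxiliary.tokenizer.networks import TokenizerConfigs, TokenizerModels

config = dict(TokenizerConfigs.CV.value)        # continuous video tokenizer config
tokenizer = TokenizerModels.CV.value(**config)  # CausalContinuousVideoTokenizer(**config)
print(type(tokenizer).__name__, config["spatial_compression"], config["temporal_compression"])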
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The continuous image tokenizer with VAE or AE formulation for 2D data.""" - -from collections import OrderedDict, namedtuple - -import torch -from loguru import logger as logging -from torch import nn - -from cosmos_transfer1.auxiliary.tokenizer.modules import ContinuousFormulation, DecoderType, EncoderType - -NetworkEval = namedtuple("NetworkEval", ["reconstructions", "posteriors", "latent"]) - - -class ContinuousImageTokenizer(nn.Module): - def __init__(self, z_channels: int, z_factor: int, latent_channels: int, **kwargs) -> None: - super().__init__() - self.name = kwargs.get("name", "ContinuousImageTokenizer") - self.latent_channels = latent_channels - - encoder_name = kwargs.get("encoder", EncoderType.Default.name) - self.encoder = EncoderType[encoder_name].value(z_channels=z_factor * z_channels, **kwargs) - - decoder_name = kwargs.get("decoder", DecoderType.Default.name) - self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **kwargs) - - self.quant_conv = torch.nn.Conv2d(z_factor * z_channels, z_factor * latent_channels, 1) - self.post_quant_conv = torch.nn.Conv2d(latent_channels, z_channels, 1) - - formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name) - self.distribution = ContinuousFormulation[formulation_name].value() - logging.info(f"{self.name} based on {formulation_name} formulation, with {kwargs}.") - - num_parameters = sum(param.numel() for param in self.parameters()) - logging.info(f"model={self.name}, num_parameters={num_parameters:,}") - logging.info(f"z_channels={z_channels}, latent_channels={self.latent_channels}.") - - def encoder_jit(self): - return nn.Sequential( - OrderedDict( - [ - ("encoder", self.encoder), - ("quant_conv", self.quant_conv), - ("distribution", self.distribution), - ] - ) - ) - - def decoder_jit(self): - return nn.Sequential( - OrderedDict( - [ - ("post_quant_conv", self.post_quant_conv), - ("decoder", self.decoder), - ] - ) - ) - - def last_decoder_layer(self): - return self.decoder.conv_out - - def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - return self.distribution(moments) - - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec - - def forward(self, input) -> dict[str, torch.Tensor] | NetworkEval: - latent, posteriors = self.encode(input) - dec = self.decode(latent) - if self.training: - return dict(reconstructions=dec, posteriors=posteriors, latent=latent) - return NetworkEval(reconstructions=dec, posteriors=posteriors, latent=latent) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/continuous_video.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/continuous_video.py deleted file mode 100644 index d62042db749b1b527901fb37308c3f93c6a9e51c..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/continuous_video.py +++ /dev/null @@ -1,103 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
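[Editor's illustrative sketch] The encoder_jit/decoder_jit helpers above simply repackage named stages into an nn.Sequential so one scripted module covers the whole encode or decode path. A stand-alone illustration with dummy layers (not the real encoder/decoder):

from collections import OrderedDict

import torch
from torch import nn

encoder_jit = nn.Sequential(OrderedDict([
    ("encoder", nn.Conv2d(3, 8, 3, padding=1)),   # stand-in for the real encoder
    ("quant_conv", nn.Conv2d(8, 4, 1)),
]))
scripted = torch.jit.script(encoder_jit)
print(scripted(torch.randn(1, 3, 16, 16)).shape)   # torch.Size([1, 4, 16, 16])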
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The causal continuous video tokenizer with VAE or AE formulation for 3D data..""" -from collections import OrderedDict, namedtuple - -from loguru import logger as logging -from torch import nn - -from cosmos_transfer1.auxiliary.tokenizer.modules import ContinuousFormulation, Decoder3DType, Encoder3DType -from cosmos_transfer1.auxiliary.tokenizer.modules.layers3d import CausalConv3d - -NetworkEval = namedtuple("NetworkEval", ["reconstructions", "posteriors", "latent"]) - - -class CausalContinuousVideoTokenizer(nn.Module): - def __init__(self, z_channels: int, z_factor: int, latent_channels: int, **kwargs) -> None: - super().__init__() - self.name = kwargs.get("name", "CausalContinuousVideoTokenizer") - self.latent_channels = latent_channels - - encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name) - self.encoder = Encoder3DType[encoder_name].value(z_channels=z_factor * z_channels, **kwargs) - if kwargs.get("temporal_compression", 4) == 4: - kwargs["channels_mult"] = [2, 4] - decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name) - self.decoder = Decoder3DType[decoder_name].value(z_channels=z_channels, **kwargs) - - self.quant_conv = CausalConv3d( - z_factor * z_channels, - z_factor * latent_channels, - kernel_size=1, - padding=0, - ) - self.post_quant_conv = CausalConv3d(latent_channels, z_channels, kernel_size=1, padding=0) - - formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name) - self.distribution = ContinuousFormulation[formulation_name].value() - logging.info(f"{self.name} based on {formulation_name} formulation, with {kwargs}.") - - num_parameters = sum(param.numel() for param in self.parameters()) - logging.info(f"model={self.name}, num_parameters={num_parameters:,}") - logging.info(f"z_channels={z_channels}, latent_channels={self.latent_channels}.") - - def encoder_jit(self): - return nn.Sequential( - OrderedDict( - [ - ("encoder", self.encoder), - ("quant_conv", self.quant_conv), - ("distribution", self.distribution), - ] - ) - ) - - def decoder_jit(self): - return nn.Sequential( - OrderedDict( - [ - ("post_quant_conv", self.post_quant_conv), - ("decoder", self.decoder), - ] - ) - ) - - def last_decoder_layer(self): - return self.decoder.conv_out - - def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - return self.distribution(moments) - - def decode(self, z): - z = self.post_quant_conv(z) - return self.decoder(z) - - def forward(self, input): - latent, posteriors = self.encode(input) - reconstructions = self.decode(latent) - if self.training: - return dict( - reconstructions=reconstructions, - posteriors=posteriors, - latent=latent, - ) - return NetworkEval( - reconstructions=reconstructions, - posteriors=posteriors, - latent=latent, - ) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/discrete_image.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/discrete_image.py deleted file mode 
100644 index 9b85b41afe16844fd67291675f3f2585c8bab3ec..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/discrete_image.py +++ /dev/null @@ -1,118 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The network definition for discrete image tokenization with VQ, LFQ, FSQ or ResidualFSQ.""" -from collections import OrderedDict, namedtuple - -import torch -from loguru import logger as logging -from torch import nn - -from cosmos_transfer1.auxiliary.tokenizer.modules import DecoderType, DiscreteQuantizer, EncoderType -from cosmos_transfer1.auxiliary.tokenizer.modules.quantizers import InvQuantizerJit - -NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) - - -class DiscreteImageTokenizer(nn.Module): - def __init__(self, z_channels: int, embedding_dim: int, **kwargs) -> None: - super().__init__() - self.name = kwargs.get("name", "DiscreteImageTokenizer") - self.embedding_dim = embedding_dim - - encoder_name = kwargs.get("encoder", EncoderType.Default.name) - self.encoder = EncoderType[encoder_name].value(z_channels=z_channels, **kwargs) - - decoder_name = kwargs.get("decoder", DecoderType.Default.name) - self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **kwargs) - self.quant_conv = nn.Conv2d(z_channels, embedding_dim, 1) - self.post_quant_conv = nn.Conv2d(embedding_dim, z_channels, 1) - - quantizer_name = kwargs.get("quantizer", DiscreteQuantizer.RESFSQ.name) - if quantizer_name == DiscreteQuantizer.VQ.name: - assert "num_embeddings" in kwargs, f"`num_embeddings` must be provided for {quantizer_name}." - kwargs.update(dict(embedding_dim=embedding_dim)) - elif quantizer_name == DiscreteQuantizer.LFQ.name: - assert "codebook_size" in kwargs, f"`codebook_size` must be provided for {quantizer_name}." - assert "codebook_dim" in kwargs, f"`codebook_dim` must be provided for {quantizer_name}." - elif quantizer_name == DiscreteQuantizer.FSQ.name: - assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}." - elif quantizer_name == DiscreteQuantizer.RESFSQ.name: - assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}.name." - assert "num_quantizers" in kwargs, f"`num_quantizers` must be provided for {quantizer_name}." 
- self.quantizer = DiscreteQuantizer[quantizer_name].value(**kwargs) - logging.info(f"{self.name} based on {quantizer_name}-VAE, with {kwargs}.") - - num_parameters = sum(param.numel() for param in self.parameters()) - logging.info(f"model={self.name}, num_parameters={num_parameters:,}") - logging.info(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.") - - def to(self, *args, **kwargs): - setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16)) - return super(DiscreteImageTokenizer, self).to(*args, **kwargs) - - def encoder_jit(self): - return nn.Sequential( - OrderedDict( - [ - ("encoder", self.encoder), - ("quant_conv", self.quant_conv), - ("quantizer", self.quantizer), - ] - ) - ) - - def decoder_jit(self): - return nn.Sequential( - OrderedDict( - [ - ("inv_quant", InvQuantizerJit(self.quantizer)), - ("post_quant_conv", self.post_quant_conv), - ("decoder", self.decoder), - ] - ) - ) - - def last_decoder_layer(self): - return self.decoder.conv_out - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return self.quantizer(h) - - def decode(self, quant): - quant = self.post_quant_conv(quant) - return self.decoder(quant) - - def decode_code(self, code_b): - quant_b = self.quantizer.indices_to_codes(code_b) - quant_b = self.post_quant_conv(quant_b) - return self.decoder(quant_b) - - def forward(self, input): - quant_info, quant_codes, quant_loss = self.encode(input) - reconstructions = self.decode(quant_codes) - if self.training: - return dict( - reconstructions=reconstructions, - quant_loss=quant_loss, - quant_info=quant_info, - ) - return NetworkEval( - reconstructions=reconstructions, - quant_loss=quant_loss, - quant_info=quant_info, - ) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/discrete_video.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/discrete_video.py deleted file mode 100644 index 5c5ec0f334f2a64621f174fce2cb24d07754965c..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/tokenizer/networks/discrete_video.py +++ /dev/null @@ -1,120 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
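[Editor's illustrative sketch] The tokenizers above return a plain dict while training and a NetworkEval namedtuple in eval mode. A minimal stand-alone illustration of that convention, with a dummy module and field names copied from the code above:

from collections import namedtuple

import torch
from torch import nn

NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"])

class TinyTokenizer(nn.Module):
    def forward(self, x):
        out = dict(reconstructions=x, quant_loss=torch.zeros(()), quant_info=None)
        return out if self.training else NetworkEval(**out)

m = TinyTokenizer()
print(type(m(torch.randn(1, 3, 8, 8))).__name__)   # dict (training mode)
m.eval()
print(type(m(torch.randn(1, 3, 8, 8))).__name__)   # NetworkEval (inference)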
- -"""The network definition for discrete video tokenizer with VQ, LFQ, FSQ or ResidualFSQ.""" -from collections import OrderedDict, namedtuple - -import torch -from loguru import logger as logging -from torch import nn - -from cosmos_transfer1.auxiliary.tokenizer.modules import Decoder3DType, DiscreteQuantizer, Encoder3DType -from cosmos_transfer1.auxiliary.tokenizer.modules.layers3d import CausalConv3d -from cosmos_transfer1.auxiliary.tokenizer.modules.quantizers import InvQuantizerJit - -NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) - - -class CausalDiscreteVideoTokenizer(nn.Module): - def __init__(self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs) -> None: - super().__init__() - self.name = kwargs.get("name", "CausalDiscreteVideoTokenizer") - self.embedding_dim = embedding_dim - - encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name) - self.encoder = Encoder3DType[encoder_name].value(z_channels=z_factor * z_channels, **kwargs) - - decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name) - self.decoder = Decoder3DType[decoder_name].value(z_channels=z_channels, **kwargs) - - self.quant_conv = CausalConv3d(z_factor * z_channels, embedding_dim, kernel_size=1, padding=0) - self.post_quant_conv = CausalConv3d(embedding_dim, z_channels, kernel_size=1, padding=0) - - quantizer_name = kwargs.get("quantizer", DiscreteQuantizer.RESFSQ.name) - if quantizer_name == DiscreteQuantizer.VQ.name: - assert "num_embeddings" in kwargs, f"`num_embeddings` must be provided for {quantizer_name}." - kwargs.update(dict(embedding_dim=embedding_dim)) - elif quantizer_name == DiscreteQuantizer.LFQ.name: - assert "codebook_size" in kwargs, f"`codebook_size` must be provided for {quantizer_name}." - assert "codebook_dim" in kwargs, f"`codebook_dim` must be provided for {quantizer_name}." - elif quantizer_name == DiscreteQuantizer.FSQ.name: - assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}." - elif quantizer_name == DiscreteQuantizer.RESFSQ.name: - assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}." - assert "num_quantizers" in kwargs, f"`num_quantizers` must be provided for {quantizer_name}." 
- self.quantizer = DiscreteQuantizer[quantizer_name].value(**kwargs) - logging.info(f"{self.name} based on {quantizer_name}-VAE, with {kwargs}.") - - num_parameters = sum(param.numel() for param in self.parameters()) - logging.info(f"model={self.name}, num_parameters={num_parameters:,}") - logging.info(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.") - - def to(self, *args, **kwargs): - setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16)) - return super(CausalDiscreteVideoTokenizer, self).to(*args, **kwargs) - - def encoder_jit(self): - return nn.Sequential( - OrderedDict( - [ - ("encoder", self.encoder), - ("quant_conv", self.quant_conv), - ("quantizer", self.quantizer), - ] - ) - ) - - def decoder_jit(self): - return nn.Sequential( - OrderedDict( - [ - ("inv_quant", InvQuantizerJit(self.quantizer)), - ("post_quant_conv", self.post_quant_conv), - ("decoder", self.decoder), - ] - ) - ) - - def last_decoder_layer(self): - return self.decoder.conv_out - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return self.quantizer(h) - - def decode(self, quant): - quant = self.post_quant_conv(quant) - return self.decoder(quant) - - def decode_code(self, code_b): - quant_b = self.quantizer.indices_to_codes(code_b) - quant_b = self.post_quant_conv(quant_b) - return self.decoder(quant_b) - - def forward(self, input): - quant_info, quant_codes, quant_loss = self.encode(input) - reconstructions = self.decode(quant_codes) - if self.training: - return dict( - reconstructions=reconstructions, - quant_loss=quant_loss, - quant_info=quant_info, - ) - return NetworkEval( - reconstructions=reconstructions, - quant_loss=quant_loss, - quant_info=quant_info, - ) diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/upsampler/README.md b/cosmos-transfer1/cosmos_transfer1/auxiliary/upsampler/README.md deleted file mode 100644 index e796a4f37be0aa42bb7e7a791ca0e069d201d8a5..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/upsampler/README.md +++ /dev/null @@ -1,3 +0,0 @@ -## Cosmos-transfer1 Prompt Upsampler - -You can use `PYTHONPATH=. python cosmos_transfer1/auxiliary/upsampler/inference/upsampler_pipeline.py --prompt "a robotic arm hand over a coffee cup to a woman in a modern office." --input_video assets/example1_input_video.mp4` to upsample the prompt. diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/upsampler/inference/upsampler_pipeline.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/upsampler/inference/upsampler_pipeline.py deleted file mode 100644 index 1c7f321c9c8cf04a2eb5f1b2294f9023d5b1d444..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/upsampler/inference/upsampler_pipeline.py +++ /dev/null @@ -1,72 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse - -from cosmos_transfer1.auxiliary.upsampler.model.upsampler import PixtralPromptUpsampler -from cosmos_transfer1.utils.misc import extract_video_frames - - -def parse_args(): - parser = argparse.ArgumentParser(description="Prompt upsampler pipeline") - parser.add_argument("--prompt", type=str, required=False, help="Prompt to upsample") - parser.add_argument("--input_video", type=str, required=True, help="Path to input video file") - parser.add_argument( - "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" - ) - parser.add_argument( - "--offload_prompt_upsampler", action="store_true", help="Offload prompt upsampler model after inference" - ) - return parser.parse_args() - - -def main(): - args = parse_args() - model = PixtralPromptUpsampler(args.checkpoint_dir, offload_prompt_upsampler=args.offload_prompt_upsampler) - - # Upsample the prompt with the given video - frame_paths = extract_video_frames(args.input_video) - upsampled_prompt = model._prompt_upsample_with_offload(args.prompt, frame_paths) - print("Upsampled prompt:", upsampled_prompt) - - -if __name__ == "__main__": - import os - - rank = int(os.environ["RANK"]) - - dist_keys = [ - "RANK", - "LOCAL_RANK", - "WORLD_SIZE", - "LOCAL_WORLD_SIZE", - "GROUP_RANK", - "ROLE_RANK", - "ROLE_NAME", - "OMP_NUM_THREADS", - "MASTER_ADDR", - "MASTER_PORT", - "TORCHELASTIC_USE_AGENT_STORE", - "TORCHELASTIC_MAX_RESTARTS", - "TORCHELASTIC_RUN_ID", - "TORCH_NCCL_ASYNC_ERROR_HANDLING", - "TORCHELASTIC_ERROR_FILE", - ] - - for dist_key in dist_keys: - del os.environ[dist_key] - - if rank == 0: - main() diff --git a/cosmos-transfer1/cosmos_transfer1/auxiliary/upsampler/model/upsampler.py b/cosmos-transfer1/cosmos_transfer1/auxiliary/upsampler/model/upsampler.py deleted file mode 100644 index 33ec0a9863d9f20ca879414c323eaf89baea0de7..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/auxiliary/upsampler/model/upsampler.py +++ /dev/null @@ -1,130 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import gc -import os -from typing import Union - -import torch -from vllm import LLM, SamplingParams - -from cosmos_transfer1.checkpoints import COSMOS_UPSAMPLER_CHECKPOINT -from cosmos_transfer1.utils.misc import extract_video_frames, image_to_base64 - - -class PixtralPromptUpsampler: - def __init__(self, checkpoint_dir: str, offload_prompt_upsampler: bool = False): - """ - Initializes the Upsampler model. - Args: - checkpoint_dir (str): The directory where model checkpoints are stored. - offload_prompt_upsampler (bool, optional): If True, the upsampler model will not be loaded during initialization. Defaults to False. 
- """ - - self.offload_prompt_upsampler = offload_prompt_upsampler - self.checkpoint_dir = checkpoint_dir - if not self.offload_prompt_upsampler: - self._load_upsampler_model() - - def _load_upsampler_model(self): - """ - Loads the upsampler model. - Sets: - self.upsampler_model: An instance of VLM initialized with the specified model configuration. - self.sampling_params: An instance of SamplingParams with predefined parameters. - """ - model_path = os.path.join(self.checkpoint_dir, COSMOS_UPSAMPLER_CHECKPOINT) - - self.upsampler_model = LLM( - model=model_path, - tensor_parallel_size=1, - tokenizer_mode="mistral", - gpu_memory_utilization=0.98, - max_model_len=4096, - max_num_seqs=2, - limit_mm_per_prompt={"image": 2}, - enable_prefix_caching=True, - ) - self.sampling_params = SamplingParams( - temperature=0.6, - max_tokens=300, - ) - - def _prompt_upsample_with_offload(self, prompt: str, video_path: Union[list, str]): - """ - Upsamples the given prompt using the upsampler model, with optional model offloading. - This method first checks if the upsampler model should be offloaded. If so, it loads the model, - performs the upsampling, and then offloads the model again if necessary. - Args: - prompt (str): The prompt to be upsampled. - image_paths (list): A list of paths to the images associated with the prompt. - Returns: - upsampled_prompt: The upsampled version of the input prompt. - """ - - if self.offload_prompt_upsampler: - self._load_upsampler_model() - - upsampled_prompt = self._prompt_upsample(prompt, video_path) - - if self.offload_prompt_upsampler: - self._offload_upsampler_model() - return upsampled_prompt - - def _prompt_upsample(self, prompt: str, video_path: Union[list, str]): - """ - Generates an upsampled image based on the provided prompt and image paths. - Args: - prompt (str): The textual prompt to guide the upsampling process. - image_paths (list of str): List of file paths to the images to be upsampled. - Returns: - str: The text output from the language model after processing the prompt and images. - """ - prompt = prompt if prompt else "describe the following images" - image_paths = video_path - if isinstance(video_path, str): - image_paths = extract_video_frames(video_path) - - message = [ - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_to_base64(image_path)}"}} - for image_path in image_paths - ] - + [ - {"type": "text", "text": prompt}, - ], - } - ] - - outputs = self.upsampler_model.chat( - [message], - sampling_params=self.sampling_params, - ) - output = outputs[0].outputs[0].text if outputs[0].outputs[0].text else "" - return str(output).strip() - - def _offload_upsampler_model(self): - """ - Offloads the upsampler model from memory. - This method deletes the `upsampler_model` attribute if it exists, sets it to None, - triggers garbage collection, and clears the CUDA cache to free up GPU memory. - """ - if self.upsampler_model: - del self.upsampler_model - self.upsampler_model = None - gc.collect() - torch.cuda.empty_cache() diff --git a/cosmos-transfer1/cosmos_transfer1/checkpointer/base.py b/cosmos-transfer1/cosmos_transfer1/checkpointer/base.py deleted file mode 100644 index bd4490c6a0d53e4267d3727335c48038e0702a1f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/checkpointer/base.py +++ /dev/null @@ -1,127 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
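[Editor's illustrative sketch] A trimmed sketch of the vLLM chat pattern that PixtralPromptUpsampler._prompt_upsample above relies on. The model path and image file are placeholders, and image_to_base64 is re-implemented locally rather than imported from the deleted utils:

import base64

from vllm import LLM, SamplingParams

def image_to_base64(path: str) -> str:
    # Local stand-in for cosmos_transfer1.utils.misc.image_to_base64.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

llm = LLM(model="/path/to/Pixtral-12B", tokenizer_mode="mistral", max_model_len=4096)
message = [{
    "role": "user",
    "content": [
        {"type": "image_url",
         "image_url": {"url": f"data:image/jpeg;base64,{image_to_base64('frame_0.jpg')}"}},
        {"type": "text", "text": "describe the following images"},
    ],
}]
outputs = llm.chat([message], sampling_params=SamplingParams(temperature=0.6, max_tokens=300))
print(outputs[0].outputs[0].text.strip())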
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from abc import ABC, abstractmethod -from typing import Optional - -import torch - -from cosmos_transfer1.utils import callback -from cosmos_transfer1.utils.config import CheckpointConfig, JobConfig -from cosmos_transfer1.utils.easy_io import easy_io -from cosmos_transfer1.utils.model import Model - - -class AbstractCheckpointer(ABC): - """The checkpointer class. Supports checkpoint saving/loading to local disk.""" - - def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): - """Constructor of the checkpointer. - - Args: - config_checkpoint (CheckpointConfig): The config object for the checkpointer. - """ - self.config_checkpoint = config_checkpoint - # Set the callback functions. - self.callbacks = callbacks - - # Set checkpoint directories for local paths - self._local_dirname = os.path.join(config_job.path_local, "checkpoints") - - self.strict_resume = config_checkpoint.strict_resume - self.load_path = config_checkpoint.load_path or None - self.load_training_state = config_checkpoint.load_training_state - self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state - self.save_thread = None - self.verbose = config_checkpoint.verbose - self.keys_not_to_resume = config_checkpoint.keys_not_to_resume - self.broadcast_via_filesystem = config_checkpoint.broadcast_via_filesystem - - @abstractmethod - def save( - self, - model: Model, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int, - ) -> None: - pass - - @abstractmethod - def load( - self, - model: Model, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, - grad_scaler: Optional[torch.amp.GradScaler] = None, - ) -> int: - pass - - @property - def save_bucket(self): - """Get the bucket name for saving checkpoints.""" - return None - - @property - def load_bucket(self): - """Get the bucket name for loading checkpoints.""" - return None - - @property - def save_dirname(self): - return self._local_dirname - - @property - def load_dirname(self): - return self._local_dirname - - def finalize(self) -> None: - """Finalize the checkpointer.""" - if self.save_thread: - self.save_thread.join() - - def _read_latest_checkpoint_file(self) -> str | None: - """Get the file name of the latest saved checkpoint. If it doesn't exist, return None. - - Returns: - checkpoint_file (str | None): file name of the latest saved checkpoint. - """ - checkpoint_file = None - checkpoint_path = os.path.join(self.load_dirname, "latest_checkpoint.txt") - if easy_io.exists(checkpoint_path): - checkpoint_file = easy_io.load(checkpoint_path).strip() - - return checkpoint_file - - def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: - """Track the file name of the latest saved checkpoint. 
- - Args: - checkpoint_file (str): file name of the latest saved checkpoint. - """ - content = f"{checkpoint_file}\n" - checkpoint_path = os.path.join(self.save_dirname, "latest_checkpoint.txt") - easy_io.dump(content, checkpoint_path) - - def _check_checkpoint_exists(self, checkpoint_path: str) -> None: - """If the file checkpoint_path does not exist, raise an error. - - Args: - checkpoint_path (str): full path to the checkpoint. - """ - if not easy_io.exists(checkpoint_path): - raise FileNotFoundError(f"File not found: {checkpoint_path}") diff --git a/cosmos-transfer1/cosmos_transfer1/checkpointer/ddp_checkpointer.py b/cosmos-transfer1/cosmos_transfer1/checkpointer/ddp_checkpointer.py deleted file mode 100644 index 4e230752f099822bd7747649278e43900f01a1f4..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/checkpointer/ddp_checkpointer.py +++ /dev/null @@ -1,437 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import threading -from collections import namedtuple -from typing import Any, Dict, Optional, Set, Tuple, Union - -import torch -import torch.distributed -from megatron.core import parallel_state -from torch.distributed import ProcessGroup, get_process_group_ranks - -from cosmos_transfer1.checkpointer.base import AbstractCheckpointer -from cosmos_transfer1.checkpointer.safe_broadcast import broadcast_object -from cosmos_transfer1.utils import distributed, log, misc -from cosmos_transfer1.utils.easy_io import easy_io -from cosmos_transfer1.utils.model import Model - -StateDictItemPath = namedtuple("StateDictItemPath", ["state_dict", "save_path"]) - - -class Checkpointer(AbstractCheckpointer): - """ - Checkpointer for DDP. - Note: This implementation only supports local filesystem. - """ - - KEYS_TO_SAVE = ["model", "optim", "scheduler", "trainer"] - KEYS_TO_POSTFIX = { - "model": "model", - "optim": "optim", - "scheduler": "scheduler", - "trainer": "", - } - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() - ep_world_size = parallel_state.get_expert_model_parallel_world_size() - assert pp_world_size < 2, "Pipeline Parallelism (PP) is not tested yet." - assert ep_world_size < 2, "Expert Parallelism (EP) is not tested yet." - self.mp_world_size = parallel_state.get_model_parallel_group().size() - if self.mp_world_size > 1 and self.__class__ == Checkpointer: - raise NotImplementedError( - "Model Parallelism (MP) is enabled - " - "you should use TensorParallel Checkpointer instead of DDP Checkpointer." 
- ) - # DDP rank (with context parallelism considered) - self.rank_dp_w_cp = parallel_state.get_data_parallel_rank(with_context_parallel=True) - # Context parallelism rank - self.cp_rank = parallel_state.get_context_parallel_rank() - # Model parallelism rank (including Tensor+Pipeline+Expert Parallelisms) - self.mp_rank = parallel_state.get_model_parallel_group().rank() - # self.mp_rank = parallel_state.get_model_parallel_group(with_expert_parallel=ep_world_size > 1).rank() - if self.broadcast_via_filesystem: - log.info("Broadcasting checkpoint data via the local filesystem.") - if not self.strict_resume: - log.warning("Strict resume mode is off. Some model parameters may not be loaded.") - - # collect ranks of all model parallel groups - all_ranks = [None for _ in range(distributed.get_world_size())] - torch.distributed.all_gather_object( - all_ranks, get_process_group_ranks(parallel_state.get_model_parallel_group()) - ) - all_ranks = list(set(tuple(rank) if isinstance(rank, list) else rank for rank in all_ranks)) - for ranks in all_ranks: - group = torch.distributed.new_group(list(ranks), backend="gloo") - if distributed.get_rank() in ranks: - self.mp_gloo_pg = group - - self.print("Checkpointer Initialized.") - - def print(self, message: str): - """ - Print message to the console. Include the parallelism rank information when verbose is set to True. - """ - if self.verbose: - log.info( - f"[Parallelism Rank: DP-{self.rank_dp_w_cp}, TP-{self.mp_rank}, CP-{self.cp_rank}]: {message}", - rank0_only=False, - ) - else: - log.info(message, rank0_only=True) - - def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str: - del model - assert key in self.KEYS_TO_SAVE - post_fix = self.KEYS_TO_POSTFIX[key] - - if post_fix: - _ckpt_path = checkpoint_path.replace(".pt", f"_{post_fix}.pt") - else: - _ckpt_path = checkpoint_path - return _ckpt_path - - def save( - self, - model: Model, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int, - ) -> None: - """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. - - Args: - model (Model): The PyTorch model. - optimizer (torch.optim.Optimizer): The model optimizer. - scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. - grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). - iteration (int): Current iteration number. - """ - self.callbacks.on_save_checkpoint_start(model, iteration) - - checkpoint_file = self.format_checkpoint_filename(model, iteration) - state_dict = self.generate_save_state_dict(model, optimizer, scheduler, grad_scaler, iteration) - state_dict = self._map_state_dict_path_during_save(state_dict, checkpoint_file, model) - if state_dict: - # Wait for previous saver thread to end. - if self.save_thread: - self.save_thread.join() - # Run the checkpoint saver in a separate thread. - self.save_thread = threading.Thread( - target=self._save_worker, - daemon=False, - args=(state_dict, checkpoint_file, distributed.get_rank()), - ) - self.save_thread.start() - - # Note: Checkpoints are saved on a separate thread and this callback is not accurate. 
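# --- Editor's sketch (not part of the deleted file above) ---
# save() hands the actual disk I/O to a background thread so training is not
# blocked, and joins any previous saver thread first so at most one checkpoint
# write is in flight. A minimal, framework-free version of that pattern
# (pickle stands in for easy_io.dump/torch.save):
import pickle
import threading

class AsyncSaverSketch:
    def __init__(self):
        self.save_thread = None

    def _save_worker(self, state_dict: dict, path: str) -> None:
        # Blocking write, executed off the training thread.
        with open(path, "wb") as f:
            pickle.dump(state_dict, f)

    def save(self, state_dict: dict, path: str) -> None:
        if self.save_thread is not None:
            self.save_thread.join()  # wait for the previous checkpoint to finish
        self.save_thread = threading.Thread(
            target=self._save_worker, daemon=False, args=(state_dict, path)
        )
        self.save_thread.start()  # return to training immediately

    def finalize(self) -> None:
        # Mirror Checkpointer.finalize(): make sure the last write completes.
        if self.save_thread is not None:
            self.save_thread.join()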
- # Please check logs from on_save_checkpoint_success() for better accuracy - self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) - - def _map_state_dict_path_during_save(self, state_dict, checkpoint_file, model) -> dict[str, StateDictItemPath]: - new_dict = {} - for key, _state_dict in state_dict.items(): - _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_file, model) - checkpoint_path = os.path.join(self.save_dirname, _ckpt_path) - new_dict[key] = StateDictItemPath(_state_dict, checkpoint_path) - return new_dict - - @misc.timer("checkpoint saving") - def _save_worker(self, state_dict: dict[str, StateDictItemPath], checkpoint_file: str, rank: int = 0) -> None: - """Worker to save checkpoint to disk, spawned with a child thread (in parallel with the training). - - Args: - state_dict (dict[str, StateDictItemPath]): The state dict of the model/optimizer/scheduler. - checkpoint_file (str): The file name of the model checkpoint. - rank (int): GPU device (default: 0). - """ - try: - for key, item in state_dict.items(): - self.print(f"Saving {key} to {item.save_path}") - try: - easy_io.dump( - item.state_dict, - item.save_path, - fast_backend=True, # optional for fast backend, cpu heavy - ) - abs_path = os.path.abspath(item.save_path) - self.print(f"Saved {key} to {item.save_path}, abspath = {abs_path}") - except Exception as e: - self.print(f"Failed to save {key} to {item.save_path}: {str(e)}") - raise # Re-raise the exception after logging - - # Synchronize only rank 0 of each model parallel group - if self.mp_world_size > 1: - torch.distributed.barrier(group=self.mp_gloo_pg) - - # Only rank 0 of MP group and rank 0 of DP with CP updates latest_checkpoint.txt - if self.mp_rank == 0 and self.rank_dp_w_cp == 0: - self._write_latest_checkpoint_file(checkpoint_file) - - if distributed.get_rank() == 0: # only rank 0 saves trained_data_record - if "trained_data_record" in state_dict["model"].state_dict: - self._write_trained_data_record( - checkpoint_file, state_dict["model"].state_dict["trained_data_record"] - ) - - iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) - self.callbacks.on_save_checkpoint_success(iteration=iteration) - except Exception as e: # noqa: BLE001 - log.exception(f"Checkpoint failed to upload: {e}", rank0_only=not self.verbose) - - def format_checkpoint_filename(self, model: Model, iteration: int) -> str: - """Generate the checkpoint file name. - - Args: - iteration (int): The current iteration number. - - Returns: - checkpoint_file (str): The checkpoint file name. 
- """ - del self, model - return f"iter_{iteration:09}.pt" - - @misc.timer("generate saving state dict") - def generate_save_state_dict( - self, - model: Model, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int, - ) -> Optional[Dict[str, Any]]: - state_dict = {} - - if self.rank_dp_w_cp == 0: - trainer_state = dict( - grad_scaler=grad_scaler.state_dict(), - iteration=iteration, - ) - model_state = model.state_dict() - optim_state = optimizer.state_dict() - scheduler_state = scheduler.state_dict() - self.callbacks.on_save_checkpoint(model, state_dict=trainer_state) - - trainer_state, model_state, optim_state, scheduler_state = misc.to( - [trainer_state, model_state, optim_state, scheduler_state], device="cpu" - ) - - state_dict = { - "model": model_state, - "optim": optim_state, - "scheduler": scheduler_state, - } - if distributed.get_rank() == 0: # only rank 0 saves trainer state - state_dict["trainer"] = trainer_state - return state_dict - return state_dict - - def load_broadcast_state_dict(self, checkpoint_path: str, model: Model, resume_keys: Set) -> dict[str, Any]: - """ - Load state_dict and broadcast. - - The main steps are: - 1. Download TP-rank-specific checkpoints for every GPU of DDP-rank 0 and CP-rank 0. - 2. Each rank loads its corresponding checkpoint from the local cache or receives it via broadcast. - - This approach ensures that each MP rank loads its specific part of the model, which is - crucial for Model Parallelism where different parts of the model are distributed across - multiple GPUs. - - When using Model Parallelism (e.g., Tensor Parallelism), the `broadcast_via_filesystem` option can - be set to True. This allows each rank to load its specific checkpoint from the local filesystem - instead of receiving it via network broadcast, which could be more efficient in some cases. - - For standard DDP without TP, `broadcast_via_filesystem` should remain False (default). - - Args: - checkpoint_path (str): The base path of the checkpoint. - model (Model): The model being loaded. - resume_keys (Set): Set of keys to resume from the checkpoint. - - Returns: - dict[str, Any]: A dictionary containing the loaded state for each resumed key. - """ - state_dict = {} - sorted_resume_keys = sorted(resume_keys) - # Step 1: Download TP-rank-specific checkpoints for every GPU of DDP-rank 0 and CP-rank 0. - if self.rank_dp_w_cp == 0: - for key in sorted_resume_keys: - _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) - local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) - if os.path.exists(local_cache_path): - # If the local checkpoint exists, we can directly load it - self.print(f"Checkpoint is already in local cache: {local_cache_path}. 
Loading...") - _state_dict = easy_io.load(local_cache_path, fast_backend=True) - else: - _state_dict = easy_io.load(_ckpt_path, fast_backend=True) - self.print(f"Downloading checkpoint from: {_ckpt_path}") - if self.broadcast_via_filesystem: - # Save the checkpoint to the local filesystem - easy_io.dump(_state_dict, local_cache_path, fast_backend=True) - state_dict[key] = _state_dict - # Ensure all ranks wait for the download to complete - distributed.barrier() - - # Step 2: Broadcast checkpoint data - log.info( - "Start broadcasting checkpoint from the source rank to all other ranks in the same DDP group.", - rank0_only=True, - ) - for key in sorted_resume_keys: - if self.broadcast_via_filesystem: - # Load the checkpoint from the local filesystem for other ranks - if self.rank_dp_w_cp != 0: - _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) - local_cache_path = os.path.join(self.load_dirname, os.path.basename(_ckpt_path)) - self.print(f"Loading checkpoint from: {local_cache_path}") - state_dict[key] = easy_io.load(local_cache_path, fast_backend=True) - else: - # Broadcast the checkpoint to all GPUs of the current DDP rank - group: ProcessGroup = parallel_state.get_data_parallel_group(with_context_parallel=True) - min_rank = min(get_process_group_ranks(group)) - - _state_dict = broadcast_object( - state_dict[key] if self.rank_dp_w_cp == 0 else None, - min_rank, - group=group, - device=torch.device(torch.cuda.current_device()), - ) - if self.rank_dp_w_cp == 0: - self.print(f'Broadcasted checkpoint["{key}"] to all other ranks in the same DDP group.') - else: - state_dict[key] = _state_dict - self.print(f'Received checkpoint["{key}"] from source rank {min_rank}.') - - return state_dict - - def keys_to_resume_during_load(self) -> Tuple[Set, Union[str, None]]: - latest_checkpoint_file = self._read_latest_checkpoint_file() - - resume_keys = [] - - if latest_checkpoint_file is not None: - # 1. Resume training from latest_checkpoint.txt under the same name. - checkpoint_path = os.path.join(self.load_dirname, latest_checkpoint_file) - resume_keys.extend(self.KEYS_TO_SAVE) - else: - if self.load_path: - # 2. Load the module weights specified by config_checkpoint.path. - checkpoint_path = self.load_path - if self.load_training_state: - resume_keys.extend(self.KEYS_TO_SAVE) - else: - resume_keys.append("model") - if self.only_load_scheduler_state: - resume_keys.append("scheduler") - else: - checkpoint_path = None - if len(self.keys_not_to_resume) > 0: - for key in self.keys_not_to_resume: - assert key in self.KEYS_TO_SAVE, f"Invalid key to resume: {key} not in {self.KEYS_TO_SAVE}" - resume_keys = [key for key in resume_keys if key not in self.keys_not_to_resume] - return set(resume_keys), checkpoint_path - - @misc.timer("checkpoint loading") - def load( - self, - model: Model, - optimizer: torch.optim.Optimizer | None = None, - scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, - grad_scaler: torch.amp.GradScaler | None = None, - ) -> int: - """Load network weights and optimizer states from a checkpoint in a single process. - - The priority of the checkpoint loading logic is: - 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. - 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. - - This is typically used for inference mode. - - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. - 3. 
If none of the above, randomly initialize the model parameters and train from scratch. - - Args: - model (Model): The PyTorch model. - optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). - scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). - grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). - - Returns: - iteration (int): the iteration number to start/resume from. - """ - self.callbacks.on_load_checkpoint_start(model) - - resume_keys, checkpoint_path = self.keys_to_resume_during_load() - - iteration = 0 - - # Load checkpoint. - if checkpoint_path is not None: - self._check_checkpoint_exists(checkpoint_path) - state_dict = self.load_broadcast_state_dict(checkpoint_path, model, set(resume_keys)) - - if "trainer" in state_dict: - trainer_state = state_dict["trainer"] - log.critical(state_dict.keys(), rank0_only=False) - log.critical(trainer_state, rank0_only=False) - log.info("- Loading the gradient scaler...") - grad_scaler.load_state_dict(trainer_state["grad_scaler"]) - self.callbacks.on_load_checkpoint(model, state_dict=trainer_state) - iteration = trainer_state["iteration"] - if "optim" in state_dict: - assert optimizer - optimizer_state = state_dict["optim"] - log.info("- Loading the optimizer...") - optimizer.load_state_dict(optimizer_state) - if "scheduler" in state_dict: - assert scheduler - scheduler_state = state_dict["scheduler"] - log.info("- Loading the scheduler...") - scheduler.load_state_dict(scheduler_state) - scheduler.last_epoch = iteration - if "model" in state_dict: - model_state = state_dict["model"] - log.info("- Loading the model...") - # model.load_state_dict(model_state) - if self.strict_resume: - log.info("\t Strict resume mode is on.") - else: - log.info("\t Strict resume mode is off.") - model_load_info = model.load_state_dict(model_state, strict=self.strict_resume) - log.info(f"\t {model_load_info}") - self.print(f"Loaded checkpoint from {checkpoint_path} in iteration {iteration}") - else: - log.info("Training from scratch.") - torch.cuda.empty_cache() - - self.callbacks.on_load_checkpoint_end(model) - - return iteration - - def _write_trained_data_record(self, checkpoint_file: str, trained_data_record: dict[str, int]) -> None: - """Write json file to save number of seen samples and number of iterations. - - Args: - checkpoint_file (str): iteration number for the saved checkpoint - trained_data_record (dict[str, int]): example {"image": 0, "video": 0, "iteration": 0}. - """ - # filename: iter_xxxxxxxxx_trained_data_record.json - checkpoint_path = os.path.join( - self.save_dirname, f"{checkpoint_file.replace('.pt', '')}_trained_data_record.json" - ) - easy_io.dump(trained_data_record, checkpoint_path) diff --git a/cosmos-transfer1/cosmos_transfer1/checkpointer/ema_fsdp_checkpointer.py b/cosmos-transfer1/cosmos_transfer1/checkpointer/ema_fsdp_checkpointer.py deleted file mode 100644 index 5823743f69ff397b2fe39c3d849dff2ed45aea2a..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/checkpointer/ema_fsdp_checkpointer.py +++ /dev/null @@ -1,54 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
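# --- Editor's sketch (not part of the deleted file above) ---
# The load() paths in these checkpointers share a three-tier priority:
# (1) resume from the file named in latest_checkpoint.txt, (2) otherwise fall
# back to an explicit load_path (weights only unless the training state is
# requested), (3) otherwise start from scratch. A condensed, framework-free
# version of that decision logic, with hypothetical argument names:
import os
from typing import Optional, Tuple

def resolve_checkpoint_sketch(
    checkpoint_dir: str,
    load_path: Optional[str],
    load_training_state: bool,
) -> Tuple[Optional[str], bool]:
    latest_file = os.path.join(checkpoint_dir, "latest_checkpoint.txt")
    if os.path.isfile(latest_file):
        with open(latest_file) as f:
            name = f.read().strip()
        return os.path.join(checkpoint_dir, name), True   # full resume
    if load_path:
        return load_path, load_training_state             # weights (+ optim state if requested)
    return None, False                                     # train from scratch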
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings - -import attrs - -from cosmos_transfer1.checkpointer.fsdp_checkpointer import FSDPCheckpointer as BaseFSDPCheckpointer -from cosmos_transfer1.utils import log -from cosmos_transfer1.utils.config import CheckpointConfig as BaseCheckpointConfig -from cosmos_transfer1.utils.ddp_config import make_freezable - - -@make_freezable -@attrs.define(slots=False) -class CheckpointConfig(BaseCheckpointConfig): - load_ema_to_reg: bool = False - - -class FSDPCheckpointer(BaseFSDPCheckpointer): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if not isinstance(self.config_checkpoint, CheckpointConfig): - warnings.warn( - "The 'config_checkpoint' is not an instance of 'CheckpointConfig'. " - "This behavior is deprecated and will not be supported in future versions. " - "Please update 'config_checkpoint' to be of type 'CheckpointConfig'.", - DeprecationWarning, - ) - - self.load_ema_to_reg = False - else: - self.load_ema_to_reg = self.config_checkpoint.load_ema_to_reg - - log.critical(f"load_ema_to_reg: {self.load_ema_to_reg}", rank0_only=False) - - def load_model_during_init(self, model, is_ema: bool = False, ema_id: int = 0): - if self.load_ema_to_reg and is_ema is False: - is_ema = True - ema_id = 0 - log.critical("Loading EMA model to regular model during initialization.", rank0_only=False) - super().load_model_during_init(model, is_ema, ema_id) diff --git a/cosmos-transfer1/cosmos_transfer1/checkpointer/fast_tp.py b/cosmos-transfer1/cosmos_transfer1/checkpointer/fast_tp.py deleted file mode 100644 index 499eff3f19d7625733f3ad34e3ee90b9916eea6d..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/checkpointer/fast_tp.py +++ /dev/null @@ -1,110 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Set - -import torch - -from cosmos_transfer1.checkpointer.ddp_checkpointer import StateDictItemPath -from cosmos_transfer1.checkpointer.tp_checkpointer import Checkpointer as TPCheckpointer -from cosmos_transfer1.diffusion.training.models.model import DiffusionModel -from cosmos_transfer1.utils import distributed, log, misc -from cosmos_transfer1.utils.easy_io import easy_io - - -class Checkpointer(TPCheckpointer): - def load_broadcast_state_dict( - self, checkpoint_path: str, model: DiffusionModel, resume_keys: Set - ) -> dict[str, Any]: - """ - Load state_dict and broadcast efficiently. 
- - This method optimizes checkpoint loading for distributed training for improved - connection speed and reliability. - - The main steps are: - 1. Retrieve TP-rank-specific checkpoints for each GPU of DDP-rank 0 - and CP-rank 0. - 2. Each rank loads its corresponding checkpoint either from a local cache or - receives it via broadcast. - - This approach ensures that each MP (Model Parallelism) rank loads its specific - part of the model, which is crucial for scenarios where different parts of the - model are distributed across multiple GPUs. - - The method supports both Tensor Parallelism (TP) and standard Data Parallel (DP) - training. For TP, each rank can efficiently load its specific checkpoint from S3. - For standard DDP without TP, the default broadcast mechanism is used. - - Args: - checkpoint_path (str): The base path of the checkpoint in S3. - model (DiffusionModel): The model being loaded. - resume_keys (Set): Set of keys to resume from the checkpoint. - - Returns: - dict[str, Any]: A dictionary containing the loaded state for each resumed key. - - Note: - This implementation has been tested and optimized for 4K GPU training jobs, - showing significant improvements in connection speed and overall efficiency. - """ - state_dict = {} - sorted_resume_keys = sorted(resume_keys) - for key in sorted_resume_keys: - _ckpt_path = self.add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) - _state_dict = easy_io.load(_ckpt_path, weights_only=False) - state_dict[key] = _state_dict - self.print(f"Loaded checkpoint from: {_ckpt_path}") - distributed.barrier() - return state_dict - - @misc.timer("checkpoint saving") - def _save_worker(self, state_dict: dict[str, StateDictItemPath], checkpoint_file: str, rank: int = 0) -> None: - """ - similar to the original _save_worker, but with the following changes: - * fast_backend=False to avoid high CPU usage - """ - try: - for key, item in state_dict.items(): - self.print(f"Saving {key} to {item.save_path}") - try: - easy_io.dump( - item.state_dict, - item.save_path, - # fast_backend=False, # too cpu heavy - ) - self.print(f"Saved {key} to {item.save_path}") - except Exception as e: - self.print(f"Failed to save {key} to {item.save_path}: {str(e)}") - raise # Re-raise the exception after logging - - # Synchronize only rank 0 of each model parallel group - if self.mp_world_size > 1: - torch.distributed.barrier(group=self.mp_gloo_pg) - - # Only rank 0 of MP group and rank 0 of DP with CP updates latest_checkpoint.txt - if self.mp_rank == 0 and self.rank_dp_w_cp == 0: - self._write_latest_checkpoint_file(checkpoint_file) - - if distributed.get_rank() == 0: # only rank 0 saves trained_data_record - if "trained_data_record" in state_dict["model"].state_dict: - self._write_trained_data_record( - checkpoint_file, state_dict["model"].state_dict["trained_data_record"] - ) - - iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) - self.callbacks.on_save_checkpoint_success(iteration=iteration) - except Exception as e: # noqa: BLE001 - log.exception(f"Checkpoint failed to upload: {e}", rank0_only=not self.verbose) diff --git a/cosmos-transfer1/cosmos_transfer1/checkpointer/fsdp_checkpointer.py b/cosmos-transfer1/cosmos_transfer1/checkpointer/fsdp_checkpointer.py deleted file mode 100644 index aa2295442a9eb5cd95ab74f095968d6192d35fde..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/checkpointer/fsdp_checkpointer.py +++ /dev/null @@ -1,392 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import os -import threading - -import torch -from torch.distributed.fsdp import FullOptimStateDictConfig, FullStateDictConfig -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import StateDictType - -from cosmos_transfer1.checkpointer.fsdp_optim_fix import scatter_full_optim_state_dict -from cosmos_transfer1.utils import callback, distributed, log, misc -from cosmos_transfer1.utils.config import CheckpointConfig, JobConfig -from cosmos_transfer1.utils.model import Model - - -class FSDPCheckpointer: - """The checkpointer class. Supports checkpoint saving/loading to local disk.""" - - def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): - """Constructor of the checkpointer. - - Args: - config_checkpoint (CheckpointConfig): The config object for the checkpointer. - """ - # Set the callback functions. - self.callbacks = callbacks - self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints" - self.strict_resume = config_checkpoint.strict_resume - self.load_path = config_checkpoint.load_path - self.load_training_state = config_checkpoint.load_training_state - self.save_thread = None - self.config_checkpoint = config_checkpoint - - def _load_ckpt_file_during_init(self): - latest_checkpoint_file = self._read_latest_checkpoint_file() - if latest_checkpoint_file is not None: - # 1. Resume training from latest_checkpoint.txt under the same name. - checkpoint_dir = self.checkpoint_dir_local - checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file) - resume = True - log.critical(f"[Checkpoint] Found latest checkpoint file: {latest_checkpoint_file}") - log.critical(f"[Checkpoint] Loading from local path: {checkpoint_path}") - log.critical("[Checkpoint] Will resume full training state (model, optimizer, scheduler)") - else: - if self.load_path: - # 2. Load the module weights specified by config_checkpoint.path. - checkpoint_path = self.load_path - resume = self.load_training_state - log.critical(f"[Checkpoint] Using specified checkpoint path: {checkpoint_path}") - if resume: - log.critical("[Checkpoint] Will load complete training state (model, optimizer, scheduler)") - else: - log.critical("[Checkpoint] Will load model weights only (no optimizer/scheduler state)") - else: - # 3. Randomly initialize the model parameters and train from scratch. 
- checkpoint_path = None - resume = False - log.critical("[Checkpoint] No checkpoint path specified") - log.critical("[Checkpoint] Starting fresh training with random initialization") - return checkpoint_path, resume - - @misc.timer("FSDP.load_model_during_init") - def load_model_during_init(self, model, is_ema=False, ema_id: int = 0): - if ema_id > 0: - assert is_ema, "ema_id should be used with is_ema=True" - checkpoint_path, _ = self._load_ckpt_file_during_init() - if checkpoint_path is not None: - tag = "reg" if not is_ema else "ema" - default_checkpoint_path = checkpoint_path.replace(".pt", f"_{tag}_model.pt") - if not os.path.exists(default_checkpoint_path): - default_checkpoint_path = checkpoint_path # starting from the release checkpoint - log.warning(f"is_ema={is_ema} model is not found. Loading from {default_checkpoint_path}") - if tag == "ema" and ema_id > 0: - _checkpoint_path = checkpoint_path.replace(".pt", f"_RANK{ema_id}.pt") - _checkpoint_path = _checkpoint_path.replace(".pt", f"_{tag}_model.pt") - if self._check_checkpoint_exists(_checkpoint_path, is_raise=False): - default_checkpoint_path = _checkpoint_path - else: - print( - f"{distributed.get_rank()}: Checkpoint not found: {_checkpoint_path} " - f"(fallback to {default_checkpoint_path})" - ) - checkpoint_path = default_checkpoint_path - self._check_checkpoint_exists(checkpoint_path) - - log.info(f"Loading checkpoint (local): {checkpoint_path}") - state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) - log.success(f"Complete loading checkpoint (local): {checkpoint_path}") - log.info("- Loading the model...") - if self.strict_resume: - log.info(model.load_state_dict(state_dict, strict=self.strict_resume)) - else: - log.critical("\t Using non-strict model") - from cosmos_transfer1.diffusion.inference.inference_utils import non_strict_load_model - - log.info(non_strict_load_model(model, state_dict)) - log.info("-finish model loading") - else: - log.info(f"is_ema={is_ema} model is not found and loaded.") - - @misc.timer("FSDP.load_optim_scheduler_during_init") - def load_optim_scheduler_during_init(self, fsdp_model, optimizer, scheduler): - checkpoint_path, resume = self._load_ckpt_file_during_init() - log.critical(f"Loading optimizer and scheduler: {checkpoint_path} (resume: {resume}") - if checkpoint_path is not None: - if resume: - checkpoint_path = checkpoint_path.replace(".pt", "_optim.pt") - self._check_checkpoint_exists(checkpoint_path) - if distributed.get_rank() == 0: - log.info(f"Loading checkpoint (local): {checkpoint_path}") - state_dict = torch.load( - checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False - ) - log.success(f"Complete loading checkpoint (local): {checkpoint_path}") - log.info("- Loading the optimizer (FSDP scatter)...") - else: - state_dict = { - "optimizer": None, - "scheduler": None, - } - distributed.barrier() - sharded_optimizer_state_dict = scatter_full_optim_state_dict( # <---- FSDP - state_dict["optimizer"], - fsdp_model, - ) - log.info("- Loading the optimizer (FSDP load_state_dict)...") - log.info(optimizer.load_state_dict(sharded_optimizer_state_dict)) - log.critical("Skip loading the scheduler...") - return - log.info("- Loading the scheduler...") - scheduler.load_state_dict(state_dict["scheduler"]) - - @misc.timer("FSDP get_optim_scheduler_state") - def get_optim_scheduler_state(self, optim, fsdp_model, scheduler): - with FSDP.state_dict_type( - fsdp_model, - StateDictType.FULL_STATE_DICT, - 
FullStateDictConfig(offload_to_cpu=True, rank0_only=True), - FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), - ): - optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim) - scheduler_statedict = scheduler.state_dict() - return { - "optimizer": optim_statedict, - "scheduler": scheduler_statedict, - } - - def save( - self, - model: Model, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int, - ) -> None: - """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. - - Args: - model (Model): The PyTorch model. - optimizer (torch.optim.Optimizer): The model optimizer. - scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. - grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). - iteration (int): Current iteration number. - """ - self.callbacks.on_save_checkpoint_start(model, iteration) - - model_state_dict = model.state_dict_model() - optim_scheduler_state_dict = self.get_optim_scheduler_state(optimizer, model.model, scheduler) - torch.cuda.empty_cache() - state_dict = dict( - iteration=iteration, - ) - self.callbacks.on_save_checkpoint(model, state_dict=state_dict) - - postfix, replicate_idx, shard_idx, total_ema_num = model.get_ckpt_postfix() - if replicate_idx == 0 and shard_idx == 0: - pass # save whole; it is rank0 - elif replicate_idx < total_ema_num and shard_idx == 0: - model_state_dict["model"] = None # only save ema - optim_scheduler_state_dict = None - state_dict = None - else: - return - - checkpoint_file = f"iter_{iteration:09}{postfix}.pt" - # Wait for previous saver thread to end. - if self.save_thread: - self.save_thread.join() - # Run the checkpoint saver in a separate thread. - self.save_thread = threading.Thread( - target=self._save_worker_local, - daemon=False, - args=(model_state_dict, optim_scheduler_state_dict, state_dict, checkpoint_file, distributed.get_rank()), - ) - self.save_thread.start() - - # Note: Checkpoints are saved on a separate thread and this callback is not accurate. - # Please check logs from on_save_checkpoint_success() for better accuracy - self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) - - @misc.timer("checkpoint saving (local)") - def _save_worker_local( - self, - model_state_dict: dict[str, torch.Tensor], - optim_scheduler_state_dict: dict[str, torch.Tensor], - state_dict: dict[str, torch.Tensor], - checkpoint_file: str, - rank: int = 0, - ) -> None: - """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training). - - Args: - state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler. - checkpoint_file (str): The file name of the model checkpoint. - rank (int): GPU device (default: 0). 
- """ - checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file) - os.makedirs(self.checkpoint_dir_local, exist_ok=True) - try: - model_state_dict, ema_model_state_dict = model_state_dict["model"], model_state_dict["ema"] - if model_state_dict is not None: - torch.save(model_state_dict, checkpoint_path.replace(".pt", "_reg_model.pt")) - if ema_model_state_dict is not None: - torch.save(ema_model_state_dict, checkpoint_path.replace(".pt", "_ema_model.pt")) - if optim_scheduler_state_dict is not None: - torch.save(optim_scheduler_state_dict, checkpoint_path.replace(".pt", "_optim.pt")) - if state_dict is not None: - torch.save(state_dict, checkpoint_path) - if rank == 0: - self._write_latest_checkpoint_file(checkpoint_file) - log.success(f"Saved checkpoint (local): {checkpoint_path}") - iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) - self.callbacks.on_save_checkpoint_success(iteration=iteration) - except Exception as e: # noqa: BLE001 - log.exception(f"Checkpoint failed to save (local): {e}") - - @misc.timer("checkpoint loading") - def load( - self, - model: Model, - optimizer: torch.optim.Optimizer | None = None, - scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, - grad_scaler: torch.amp.GradScaler | None = None, - ) -> int: - """Load network weights and optimizer states from a checkpoint in a single process. - - The priority of the checkpoint loading logic is: - 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. - 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. - - This is typically used for inference mode. - - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. - 3. If none of the above, randomly initialize the model parameters and train from scratch. - - Args: - model (FSDPDiffModle): The PyTorch model. - optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). - scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). - grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). - - Returns: - iteration (int): the iteration number to start/resume from. - """ - self.callbacks.on_load_checkpoint_start(model) - - del optimizer, grad_scaler - checkpoint_path, resume = self._load_ckpt_file_during_init() - iteration = 0 - if checkpoint_path is not None: - self._check_checkpoint_exists(checkpoint_path) - log.info(f"Loading checkpoint (local): {checkpoint_path}") - state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) - log.success(f"Complete loading checkpoint (local): {checkpoint_path}") - self.callbacks.on_load_checkpoint(model, state_dict=state_dict) - if resume: - iteration = state_dict["iteration"] - log.success("Done with loading the checkpoint.") - else: - log.info("Training from scratch.") - torch.cuda.empty_cache() - - self.callbacks.on_load_checkpoint_end(model) - - if scheduler is not None: - scheduler.last_epoch = iteration - log.critical(f"resume scheduler from {iteration}", rank0_only=False) - - return iteration - - def _read_latest_checkpoint_file(self) -> str | None: - """Get the file name of the latest saved checkpoint. If it doesn't exist, return None. - - Returns: - checkpoint_file (str | None): file name of the latest saved checkpoint. 
- """ - checkpoint_file = None - latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") - if os.path.isfile(latest_path): - checkpoint_file = open(latest_path).read().strip() - if checkpoint_file is None: - log.warning(f"Latest ckpt file not found: {latest_path}") - else: - log.info(f"Found latest checkpoint: {checkpoint_file}") - return checkpoint_file - - def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: - """Track the file name of the latest saved checkpoint. - - Args: - checkpoint_file (str): file name of the latest saved checkpoint. - """ - content = f"{checkpoint_file}\n" - latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") - with open(latest_path, "w") as file: - file.write(content) - - def _check_checkpoint_exists(self, checkpoint_path: str, is_raise: bool = True) -> None: - """If the file checkpoint_path does not exist, raise an error. - - Args: - checkpoint_path (str): full path to the checkpoint. - """ - if not os.path.exists(checkpoint_path): - if is_raise: - raise FileNotFoundError(f"File not found (local): {checkpoint_path}") - return True - - def finalize(self) -> None: - """Finalize the checkpointer.""" - if self.save_thread: - self.save_thread.join() - - -class FSDPInferenceCheckpointer: - def __init__( - self, - ckpt_path: str, - strict_resume: bool = True, - ): - self.ckpt_path = ckpt_path - self.strict_resume = strict_resume - - @misc.timer("FSDPInferenceCheckpointer.load_model_during_init") - def load_model_during_init(self, model, is_ema=False, ema_id: int = 0): - del ema_id - if is_ema: - log.warning("EMA model is not supported in inference mode.") - return - assert os.path.exists(self.ckpt_path) - log.info(f"Loading from {self.ckpt_path}") - state_dict = torch.load(self.ckpt_path, map_location=lambda storage, loc: storage, weights_only=False) - if self.strict_resume: - log.info(model.load_state_dict(state_dict, strict=self.strict_resume)) - else: - log.critical("\t Using non-strict model") - from cosmos_transfer1.checkpointer.fsdp_checkpointer import non_strict_load_model - - log.info(non_strict_load_model(model, state_dict)) - log.info("-finish model loading") - - def load_optim_scheduler_during_init(self, *args, **kwargs): - """ - We do not do load in inference mode. The function is here to maintain the same interface to avoid errors. - """ - pass - - def save(self, *args, **kwargs): - """ - We do not save anything in inference mode. The function is here to maintain the same interface to avoid errors. - """ - pass - - def load(self, *args, **kwargs): - """ - We do not do load in inference mode. The function is here to maintain the same interface to avoid errors. - """ - return 0 diff --git a/cosmos-transfer1/cosmos_transfer1/checkpointer/fsdp_optim_fix.py b/cosmos-transfer1/cosmos_transfer1/checkpointer/fsdp_optim_fix.py deleted file mode 100644 index a08aa943828d4c9c385d25873528edae0a84ec24..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/checkpointer/fsdp_optim_fix.py +++ /dev/null @@ -1,351 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# flake8: noqa -# isort: skip_file - -""" -torch 2.2 has bugs in loading optimizer states for FSDP in hybrid mode -torch impl uses state.rank and dist.rank() inconsistently -The file fix the bugs. Verified it works for hybrid mode and fullly sharded mode -Please use the `scatter_full_optim_state_dict` in the code to replace the corresponding function in torch 2.2 -""" - -import copy -import warnings -from typing import Any, Dict, Iterable, List, Optional, Union - -import torch -import torch.distributed as dist -from torch import nn -from torch.distributed.fsdp import FullyShardedDataParallel -from torch.distributed.fsdp._debug_utils import SimpleProfiler -from torch.distributed.fsdp._optim_utils import ( - _flatten_optim_state, - _FSDPState, - _get_fqn_to_fsdp_param_info, - _get_param_to_fqns, - _OptimStateKey, - _PosDimTensorInfo, - _shard_orig_param_state, - tree_map_only, -) -from torch.distributed.fsdp.fully_sharded_data_parallel import _rekey_sharded_optim_state_dict - - -def _broadcast_processed_state( - fsdp_state: _FSDPState, - optim_state: Dict[str, Any], - group: Optional[dist.ProcessGroup], -) -> Dict[str, Any]: - objects: List[Any] = [None] - if fsdp_state.rank == 0: - objects[0] = tree_map_only( - torch.Tensor, - lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype), - optim_state, - ) - dist.broadcast_object_list(objects, src=0, group=group) - if dist.get_rank() == 0: - return optim_state - else: - return objects[0] - - -def _broadcast_state(fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup]) -> Any: - if dist.get_rank() == 0: - if not isinstance(state, torch.Tensor) or state.dim() == 0: - return state - tensor = state.to(fsdp_state.compute_device) - else: - if isinstance(state, torch.Tensor): - assert state.dim() == 0, ( - "For non-zero ranks, a tensor state should have zero dimension, " - "but got the state with shape {state.shape()}." - ) - return state - elif not isinstance(state, _PosDimTensorInfo): - return state - tensor = torch.zeros(state.shape, dtype=state.dtype, device=fsdp_state.compute_device) - dist.broadcast(tensor, src=0, group=group) - return tensor - - -def _flatten_optim_state_dict( - optim_state_dict: Dict[str, Any], - model: nn.Module, - use_orig_params: bool = False, - optim: Optional[torch.optim.Optimizer] = None, - rank0_only: bool = False, - group: Optional[dist.ProcessGroup] = None, -) -> Dict[str, Any]: - """ - Flattens the full optimizer state dict, still keying by unflattened parameter - names. - - If ``use_orig_params`` is True, each rank will have all FSDP-managed - parameters but some of these parameters may be empty due to the sharding. - For a regular optim.Optimizer, states for those empty parameters will - not be initialized. So, when aggregating the FQNs across ranks, no assert - will be raised on a rank even if it does not have all the states -- it is - valid and FSDP know how to aggregate them. 
However, FSDP has to ignore - handling those parameters that are not managed by FSDP and do not exist on - the local rank -- it is managed by other parallelism and FSDP does not - know ho to handle/aggregate them. - - Note that ``_flatten_tensor_optim_state`` does not need ``optim`` to - flatten/shard the state. However, NamedOptimizer and KeyedOptimizer require - all the states even if the corresponding parameters are empty. To this end, - ``optim`` will be used to to get the initial state of the empty parameters. - ``optim`` should only be non-None if the ``optim` is KeyedOptimizer or - NamedOptimizer. - - Returns: - Dict[str, Any]: The flattened optimizer state dict. - """ - SimpleProfiler.reset() - - unflat_osd = optim_state_dict - if "state" not in unflat_osd and not rank0_only: - raise ValueError('`optim_state_dict` must have the keys "state"' "to be a valid optimizer state dict") - param_to_fqns = _get_param_to_fqns(model) - fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model) - fsdp_state = next(iter(fqn_to_fsdp_param_info.values())).state - - # Broadcast unflat_osd without non-scalar tensor if rank0_only is True. - if rank0_only: - unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group) - - # Construct the "state" part - flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {} - unflat_osd_state = unflat_osd["state"] - all_state_keys = set(unflat_osd_state.keys()) - - for param, fqns in param_to_fqns.items(): - fqn = fqns[0] - if fqn not in unflat_osd_state: - continue - all_state_keys.difference_update(fqns) - - if rank0_only: - for fqn in fqns: - if not unflat_osd_state[fqn]: - continue - for state_name in unflat_osd_state[fqn].keys(): - unflat_osd_state[fqn][state_name] = _broadcast_state( - fsdp_state, unflat_osd_state[fqn][state_name], group=group - ) - fqn = fqns[0] - if fqn in fqn_to_fsdp_param_info: - fsdp_param_info = fqn_to_fsdp_param_info[fqn] - if use_orig_params: - with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING): - flat_state = _shard_orig_param_state( - fsdp_param_info, - fqn, - unflat_osd_state[fqn], - ) - else: - flat_state = _flatten_optim_state( - fsdp_param_info, - unflat_osd_state, - fqns, - ) - key = _OptimStateKey(tuple(fqns), True) - # Only include non-empty states since as expected by - # `torch.optim.Optimizer` s unless the optimizer is KeyedOptimizer - # or NamedOptimizer. - if flat_state: - flat_osd_state[key] = flat_state - elif use_orig_params: - assert len(fqns) == 1, f"use_orig_params is True but there are multiple FQNs, {fqns}." - if optim is not None: # NamedOptimizer or KeyedOptimizer case. - state = optim.state.get(param, None) # type: ignore[call-overload] - if state is not None: - flat_osd_state[key] = copy.deepcopy(state) - else: - warnings.warn(f"optim_state[{key}] is not on rank{fsdp_state.rank}.") - - else: - raise RuntimeError(f"The state of {key} is empty. This should happen when " "use_orig_params=True.") - else: # do not flatten non-FSDP parameters' states - assert len(fqns) == 1 - key = _OptimStateKey(tuple(fqns), False) - flat_osd_state[key] = copy.copy(unflat_osd_state[fqn]) - - if rank0_only: - for fqn in fqns: - if not unflat_osd_state[fqn]: - continue - for state_name, param_state in list(unflat_osd_state[fqn].items()): - if fsdp_state.rank > 0: - # Deference the tensor so that PyTorch can collect the memory. - del unflat_osd_state[fqn][state_name] - else: - # Move the tensor in the original osd back to CPU to make the - # original osd unaffected. 
- unflat_osd_state[fqn][state_name] = unflat_osd_state[fqn][state_name].cpu() - - # Handle user-defined state, states that are not associated with parameters. - for key in all_state_keys: - user_state = unflat_osd_state[key] - if isinstance(user_state, torch.Tensor) and rank0_only and use_orig_params: - user_state = _broadcast_state(fsdp_state, user_state, group=group) - flat_osd_state[key] = copy.copy(user_state) - - SimpleProfiler.dump_and_reset("FSDP _flatten_optim_state_dict() profiling: ") - # Construct the "param_groups" part -- copy as is since it will be - # rekeyed later according to the target rank's optimizer - # Only copy param_groups if it exists in unflat_osd - if "param_groups" in unflat_osd: - flat_osd_param_groups = copy.deepcopy(unflat_osd["param_groups"]) - return {"state": flat_osd_state, "param_groups": flat_osd_param_groups} - else: - return {"state": flat_osd_state} - - -def _optim_state_dict_to_load_impl( - optim_state_dict: Dict[str, Any], - model: torch.nn.Module, - optim_input: Optional[ - Union[ - List[Dict[str, Any]], - Iterable[torch.nn.Parameter], - ] - ] = None, - optim: Optional[torch.optim.Optimizer] = None, - full_state_dict: bool = True, - rank0_only: bool = False, - is_named_optimizer: bool = False, - group: Optional[dist.ProcessGroup] = None, -) -> Dict[str, Any]: - """ - The internal API that is used by all the load optim_state_dict implementations. - Given model, optim, and the saved optim_state_dict, this API adds the FSDP - internal information and internal sharding to the optim_state_dict. - """ - if full_state_dict: - FullyShardedDataParallel._warn_optim_input(optim_input) - using_optim_input = FullyShardedDataParallel._is_using_optim_input( - optim_input, - optim, - ) - else: - using_optim_input = False - assert optim_input is None and not rank0_only - - use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[0]._use_orig_params - assert all( - use_orig_params == m._use_orig_params for m in FullyShardedDataParallel.fsdp_modules(model) - ), "Not all FSDP modules have the same _use_orig_params value" - - if rank0_only and dist.get_rank(group) > 0: - optim_state_dict = {} - sharded_osd = _flatten_optim_state_dict( - optim_state_dict, - model=model, - use_orig_params=use_orig_params, - optim=(optim if is_named_optimizer else None), - rank0_only=rank0_only, - group=group, - ) - return _rekey_sharded_optim_state_dict( - sharded_osd, - model=model, - optim=optim, - optim_input=optim_input, - using_optim_input=using_optim_input, - is_named_optimizer=is_named_optimizer, - ) - - -def scatter_full_optim_state_dict( - full_optim_state_dict: Optional[Dict[str, Any]], - model: torch.nn.Module, - optim_input: Optional[ - Union[ - List[Dict[str, Any]], - Iterable[torch.nn.Parameter], - ] - ] = None, - optim: Optional[torch.optim.Optimizer] = None, - group: Optional[Any] = None, -) -> Dict[str, Any]: - """ - Scatters the full optimizer state dict from rank 0 to all other ranks, - returning the sharded optimizer state dict on each rank. The return - value is the same as :meth:`shard_full_optim_state_dict`, and on rank - 0, the first argument should be the return value of - :meth:`full_optim_state_dict`. - - Example:: - - >>> # xdoctest: +SKIP("undefined variables") - >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - >>> model, optim = ... - >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 - >>> # Define new model with possibly different world size - >>> new_model, new_optim, new_group = ... 
- >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) - >>> new_optim.load_state_dict(sharded_osd) - - .. note:: Both :meth:`shard_full_optim_state_dict` and - :meth:`scatter_full_optim_state_dict` may be used to get the - sharded optimizer state dict to load. Assuming that the full - optimizer state dict resides in CPU memory, the former requires - each rank to have the full dict in CPU memory, where each rank - individually shards the dict without any communication, while the - latter requires only rank 0 to have the full dict in CPU memory, - where rank 0 moves each shard to GPU memory (for NCCL) and - communicates it to ranks appropriately. Hence, the former has - higher aggregate CPU memory cost, while the latter has higher - communication cost. - - Args: - full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state - dict corresponding to the unflattened parameters and holding - the full non-sharded optimizer state if on rank 0; the argument - is ignored on nonzero ranks. - model (torch.nn.Module): Root module (which may or may not be a - :class:`FullyShardedDataParallel` instance) whose parameters - correspond to the optimizer state in ``full_optim_state_dict``. - optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): - Input passed into the optimizer representing either a - :class:`list` of parameter groups or an iterable of parameters; - if ``None``, then this method assumes the input was - ``model.parameters()``. This argument is deprecated, and there - is no need to pass it in anymore. (Default: ``None``) - optim (Optional[torch.optim.Optimizer]): Optimizer that will load - the state dict returned by this method. This is the preferred - argument to use over ``optim_input``. (Default: ``None``) - group (dist.ProcessGroup): Model's process group or ``None`` if - using the default process group. (Default: ``None``) - - Returns: - Dict[str, Any]: The full optimizer state dict now remapped to - flattened parameters instead of unflattened parameters and - restricted to only include this rank's part of the optimizer state. - """ - FullyShardedDataParallel._warn_legacy_optim_state_dict("scatter_full_optim_state_dict", "optim_state_dict_to_load") - return _optim_state_dict_to_load_impl( - optim_state_dict=full_optim_state_dict, - model=model, - optim_input=optim_input, - optim=optim, - full_state_dict=True, - rank0_only=True, - is_named_optimizer=False, - group=group, - ) diff --git a/cosmos-transfer1/cosmos_transfer1/checkpointer/multi_rank_checkpointer.py b/cosmos-transfer1/cosmos_transfer1/checkpointer/multi_rank_checkpointer.py deleted file mode 100644 index d6408c91ae40663befb1d7f67a1c379c2422d4af..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/checkpointer/multi_rank_checkpointer.py +++ /dev/null @@ -1,236 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
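# --- Editor's sketch (not part of the deleted file above) ---
# Intended use of the patched scatter_full_optim_state_dict above, assuming an
# already-initialized process group and an FSDP-wrapped model/optimizer pair
# (paths and names are placeholders, not a runnable end-to-end script):
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def save_full_optim_state_sketch(fsdp_model, optim, path: str) -> None:
    # Gather the full (unsharded) optimizer state; only rank 0 gets the result.
    full_osd = FSDP.full_optim_state_dict(fsdp_model, optim)
    if dist.get_rank() == 0:
        torch.save(full_osd, path)

def load_sharded_optim_state_sketch(fsdp_model, optim, path: str) -> None:
    # Rank 0 loads the full dict from disk; other ranks pass None and receive
    # their shard from the patched scatter_full_optim_state_dict defined above.
    full_osd = torch.load(path, map_location="cpu") if dist.get_rank() == 0 else None
    sharded_osd = scatter_full_optim_state_dict(full_osd, fsdp_model, optim=optim)
    optim.load_state_dict(sharded_osd)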
-# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import os -import threading -from typing import List, NamedTuple, Tuple - -import torch - -from cosmos_transfer1.utils import distributed, log, misc -from cosmos_transfer1.utils.checkpointer import Checkpointer as BaseCheckpointer -from cosmos_transfer1.utils.model import Model - -TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) -if TORCH_VERSION >= (1, 11): - from torch.ao import quantization - from torch.ao.quantization import FakeQuantizeBase, ObserverBase -elif ( - TORCH_VERSION >= (1, 8) - and hasattr(torch.quantization, "FakeQuantizeBase") - and hasattr(torch.quantization, "ObserverBase") -): - from torch import quantization - from torch.quantization import FakeQuantizeBase, ObserverBase - - -class _IncompatibleKeys( - NamedTuple( - "IncompatibleKeys", - [ - ("missing_keys", List[str]), - ("unexpected_keys", List[str]), - ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]), - ], - ) -): - pass - - -class MultiRankCheckpointer(BaseCheckpointer): - def save( - self, - model: Model, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int, - ) -> None: - """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. - - Args: - model (Model): The PyTorch model. - optimizer (torch.optim.Optimizer): The model optimizer. - scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. - grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). - iteration (int): Current iteration number. - """ - # checkpoint_file = f"iter_{iteration:09}.pt" - postfix, _, total_ema_num = model.get_ckpt_postfix() - checkpoint_file = f"iter_{iteration:09}{postfix}.pt" - save_ranks = list(range(total_ema_num)) - for _rank in save_ranks: - if distributed.get_rank() == _rank: - state_dict = dict( - model=model.state_dict(), - optimizer=optimizer.state_dict(), - scheduler=scheduler.state_dict(), - grad_scaler=grad_scaler.state_dict(), - iteration=iteration, - ) - state_dict = misc.to(state_dict, device="cpu") - self.callbacks.on_save_checkpoint(model, state_dict=state_dict) - # Wait for previous saver thread to end. - if self.save_thread: - self.save_thread.join() - # Run the checkpoint saver in a separate thread. - self.save_thread = threading.Thread( - target=self._save_worker_local, - daemon=False, - args=(state_dict, checkpoint_file, distributed.get_rank()), - ) - self.save_thread.start() - - @misc.timer("checkpoint loading") - def load( - self, - model: Model, - optimizer: torch.optim.Optimizer | None = None, - scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, - grad_scaler: torch.amp.GradScaler | None = None, - ) -> int: - """Load network weights and optimizer states from a checkpoint in a single process. - - The priority of the checkpoint loading logic is: - 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. - 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. - - This is typically used for inference mode. - - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. - 3. If none of the above, randomly initialize the model parameters and train from scratch. - - Args: - model (Model): The PyTorch model. 
- optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). - scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). - grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). - - Returns: - iteration (int): the iteration number to start/resume from. - """ - latest_checkpoint_file = self._read_latest_checkpoint_file() - if latest_checkpoint_file is not None: - # different from base checkpointer, this support multi-EMA - postfix, _, total_ema_num = model.get_ckpt_postfix() - latest_checkpoint_file = latest_checkpoint_file.replace(".pt", f"{postfix}.pt") - # 1. Resume training from latest_checkpoint.txt under the same name. - checkpoint_dir = self.checkpoint_dir_local - checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file) - resume = True - else: - if self.load_path: - # 2. Load the module weights specified by config_checkpoint.path. - checkpoint_path = self.load_path - # different from base checkpointer, this support multi-EMA - postfix, _, total_ema_num = model.get_ckpt_postfix() - checkpoint_path = checkpoint_path.replace(".pt", f"{postfix}.pt") - resume = self.load_training_state - else: - # 3. Randomly initialize the model parameters and train from scratch. - checkpoint_path = None - resume = False - # Load checkpoint. - if checkpoint_path is not None: - self._check_checkpoint_exists(checkpoint_path) - log.info(f"Loading checkpoint (local): {checkpoint_path}") - state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) - log.success(f"Complete loading checkpoint (local): {checkpoint_path}") - self.callbacks.on_load_checkpoint(model, state_dict=state_dict) - # Load the state dicts. - log.info("- Loading the model...") - log.critical(model.load_state_dict(state_dict["model"], strict=self.strict_resume)) - if resume: - iteration = state_dict["iteration"] - assert optimizer and scheduler - log.info("- Loading the optimizer...") - optimizer.load_state_dict(state_dict["optimizer"]) - log.info("- Loading the scheduler...") - scheduler.load_state_dict(state_dict["scheduler"]) - scheduler.last_epoch = iteration - log.info("- Loading the gradient scaler...") - grad_scaler.load_state_dict(state_dict["grad_scaler"]) - log.success(f"Done with loading the checkpoint (iteration {iteration}).") - else: - iteration = 0 - log.success("Done with loading the checkpoint.") - else: - # Checkpoint not found and not specified. We will train everything from scratch. - iteration = 0 - log.info("Training from scratch.") - torch.cuda.empty_cache() - return iteration - - -# https://github.com/facebookresearch/fvcore/blob/9d683aae73fb899dd35d6cf6720e5ef567761c57/fvcore/common/checkpoint.py -def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys: - # workaround https://github.com/pytorch/pytorch/issues/24139 - model_state_dict = model.state_dict() - incorrect_shapes = [] - for k in list(checkpoint_state_dict.keys()): - if k in model_state_dict: - if "_extra_state" in k: # Key introduced by TransformerEngine for FP8 - log.warning(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.") - continue - model_param = model_state_dict[k] - # Allow mismatch for uninitialized parameters - if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter): - continue - if not isinstance(model_param, torch.Tensor): - raise ValueError( - f"Find non-tensor parameter {k} in the model. 
type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not." - ) - - shape_model = tuple(model_param.shape) - shape_checkpoint = tuple(checkpoint_state_dict[k].shape) - if shape_model != shape_checkpoint: - has_observer_base_classes = ( - TORCH_VERSION >= (1, 8) - and hasattr(quantization, "ObserverBase") - and hasattr(quantization, "FakeQuantizeBase") - ) - if has_observer_base_classes: - # Handle the special case of quantization per channel observers, - # where buffer shape mismatches are expected. - def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module: - # foo.bar.param_or_buffer_name -> [foo, bar] - key_parts = key.split(".")[:-1] - cur_module = model - for key_part in key_parts: - cur_module = getattr(cur_module, key_part) - return cur_module - - cls_to_skip = ( - ObserverBase, - FakeQuantizeBase, - ) - target_module = _get_module_for_key(model, k) - if isinstance(target_module, cls_to_skip): - # Do not remove modules with expected shape mismatches - # them from the state_dict loading. They have special logic - # in _load_from_state_dict to handle the mismatches. - continue - - incorrect_shapes.append((k, shape_checkpoint, shape_model)) - checkpoint_state_dict.pop(k) - incompatible = model.load_state_dict(checkpoint_state_dict, strict=False) - # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling - missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k] - unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k] - return _IncompatibleKeys( - missing_keys=missing_keys, - unexpected_keys=unexpected_keys, - incorrect_shapes=incorrect_shapes, - ) diff --git a/cosmos-transfer1/cosmos_transfer1/checkpointer/safe_broadcast.py b/cosmos-transfer1/cosmos_transfer1/checkpointer/safe_broadcast.py deleted file mode 100644 index f914299c97f297cf43a5919b0bc8130686afa7c1..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/checkpointer/safe_broadcast.py +++ /dev/null @@ -1,95 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import io -import pickle -from typing import Any - -import torch -import torch.distributed as dist - - -# https://github.com/pytorch/pytorch/blob/main/torch/distributed/optim/zero_redundancy_optimizer.py#L29 -def broadcast_object( - obj: Any, - src_rank: int, - group: object = dist.group.WORLD, - device: torch.device = torch.device("cpu"), -) -> Any: - r""" - Broadcasts an object to the given group. - - It will be sending the object if called from the source rank and receiving - the object otherwise. - - Arguments: - obj: object to broadcast; only used if called on the source rank. - src_rank (int): source rank. 
- group (``ProcessGroup``, optional): group used for the broadcast - (default: ``dist.group.WORLD``). - device (``torch.device``, optional): device to send from or receive - to (default: ``torch.device("cpu")``). - - Returns: - The broadcasted object. - """ - if dist.get_rank() == src_rank: - # Send the object - buffer = io.BytesIO() - torch.save(obj, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL) - data = bytearray(buffer.getbuffer()) - length_tensor = torch.LongTensor([len(data)]).to(device) - data_send_tensor = torch.ByteTensor(data).to(device) - dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) - dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False) - else: - # Receive the object - length_tensor = torch.LongTensor([0]).to(device) - dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) - data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=device) - dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False) - buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) - obj = torch.load(buffer, map_location=device, weights_only=False) - return obj - - -def _recursive_copy_to_device( - value: Any, - non_blocking: bool, - device: torch.device, -) -> Any: - r""" - Recursively searches lists, tuples, dicts and copies tensors to device if possible. - - Non-tensor values are passed as-is in the result. - - .. note: These are all copies, so if there are two objects that reference - the same object, then after this call, there will be two different objects - referenced on the device. - """ - if isinstance(value, torch.Tensor): - return value.to(device, non_blocking=non_blocking) - - if isinstance(value, (list, tuple)): - values = [_recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for val in value] - return values if isinstance(value, list) else tuple(values) - - if isinstance(value, collections.abc.Mapping): - return { - key: _recursive_copy_to_device(val, non_blocking=non_blocking, device=device) for key, val in value.items() - } - - return value diff --git a/cosmos-transfer1/cosmos_transfer1/checkpointer/tp_checkpointer.py b/cosmos-transfer1/cosmos_transfer1/checkpointer/tp_checkpointer.py deleted file mode 100644 index 0420857e1b99ccd8e4881f7572339ebce6ba70a8..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/checkpointer/tp_checkpointer.py +++ /dev/null @@ -1,42 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from cosmos_transfer1.checkpointer.ddp_checkpointer import Checkpointer as DDPCheckpointer -from cosmos_transfer1.utils.model import Model - - -class Checkpointer(DDPCheckpointer): - """ - Checkpointer class for Tensor Parallelism (TP) in distributed training. 
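# A short usage sketch for broadcast_object above (illustrative; `metadata` and the helper
# name are assumptions). Every rank in the group calls the function; only the object passed
# on src_rank is sent, and all ranks return an identical copy. The device argument selects
# the transport buffer, e.g. a CUDA device when the backend is NCCL.
import torch
import torch.distributed as dist

def sync_metadata(metadata=None) -> dict:
    use_cuda = dist.get_backend() == "nccl"
    device = torch.device("cuda", torch.cuda.current_device()) if use_cuda else torch.device("cpu")
    # All ranks call broadcast_object; only the value passed on src_rank matters.
    return broadcast_object(metadata if dist.get_rank() == 0 else None, src_rank=0, device=device)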
- - This implementation supports the combination of Tensor Parallelism (TP) and Data Parallel Processing (DDP), with optional Context Parallelism (CP). - - Note: - - Fully Sharded Data Parallelism (FSDP) is not supported by this checkpointer. - - In principle, this implementation is also compatible with Pipeline Parallelism (PP) and Expert Parallelism (EP), which are other forms of model parallelism. However, PP and EP have not been tested yet. - """ - - def add_type_postfix_to_checkpoint_path(self, key: str, checkpoint_path: str, model: Model) -> str: - """ - Overwrite the `add_type_postfix_to_checkpoint_path` function of the base class (DDP checkpointer) - to append the TP-rank postfix to the checkpoint path. - """ - checkpoint_path = super().add_type_postfix_to_checkpoint_path(key, checkpoint_path, model) - if key == "trainer": - return checkpoint_path - else: - checkpoint_path = checkpoint_path.replace(".pt", f"_mp_{self.mp_rank}.pt") - - return checkpoint_path diff --git a/cosmos-transfer1/cosmos_transfer1/checkpoints.py b/cosmos-transfer1/cosmos_transfer1/checkpoints.py deleted file mode 100644 index 713a3afc364c25bd564b857fef1d3fda59cbfe7c..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/checkpoints.py +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
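# An illustration of the TP-rank postfix logic above (the file names are hypothetical).
# Assuming the base DDP checkpointer yields "iter_000010000_model.pt" for the model key,
# each tensor-parallel rank rewrites the name to carry its own mp_rank, while the shared
# trainer state keeps the un-postfixed name:
base_path = "iter_000010000_model.pt"              # hypothetical output of the DDP base class
rank0_path = base_path.replace(".pt", "_mp_0.pt")  # -> iter_000010000_model_mp_0.pt
rank1_path = base_path.replace(".pt", "_mp_1.pt")  # -> iter_000010000_model_mp_1.pt
trainer_path = "iter_000010000_trainer.pt"         # the trainer checkpoint is not sharded per TP rank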
- -# Cosmos Transfer1 Model Checkpoints -COSMOS_TRANSFER1_7B_CHECKPOINT = "nvidia/Cosmos-Transfer1-7B" -COSMOS_TRANSFER1_7B_SAMPLE_AV_CHECKPOINT = "nvidia/Cosmos-Transfer1-7B-Sample-AV" -COSMOS_TOKENIZER_CHECKPOINT = "nvidia/Cosmos-Tokenize1-CV8x8x8-720p" -COSMOS_UPSAMPLER_CHECKPOINT = "nvidia/Cosmos-UpsamplePrompt1-12B-Transfer" -COSMOS_GUARDRAIL_CHECKPOINT = "nvidia/Cosmos-Guardrail1" - -# 3rd Party Model Checkpoints -SAM2_MODEL_CHECKPOINT = "facebook/sam2-hiera-large" -DEPTH_ANYTHING_MODEL_CHECKPOINT = "depth-anything/Depth-Anything-V2-Small-hf" -GROUNDING_DINO_MODEL_CHECKPOINT = "IDEA-Research/grounding-dino-tiny" -T5_MODEL_CHECKPOINT = "google-t5/t5-11b" -LLAMA_GUARD_3_MODEL_CHECKPOINT = "meta-llama/Llama-Guard-3-8B" - -# Internal Checkpoint Paths, please append _PATH to the end of the variable -BASE_7B_CHECKPOINT_PATH = f"{COSMOS_TRANSFER1_7B_CHECKPOINT}/base_model.pt" -VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH = f"{COSMOS_TRANSFER1_7B_CHECKPOINT}/vis_control.pt" -EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH = f"{COSMOS_TRANSFER1_7B_CHECKPOINT}/edge_control.pt" -SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH = f"{COSMOS_TRANSFER1_7B_CHECKPOINT}/seg_control.pt" -DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH = f"{COSMOS_TRANSFER1_7B_CHECKPOINT}/depth_control.pt" -KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH = f"{COSMOS_TRANSFER1_7B_CHECKPOINT}/keypoint_control.pt" -UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH = f"{COSMOS_TRANSFER1_7B_CHECKPOINT}/4kupscaler_control.pt" -BASE_7B_CHECKPOINT_AV_SAMPLE_PATH = f"{COSMOS_TRANSFER1_7B_SAMPLE_AV_CHECKPOINT}/base_model.pt" -HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH = f"{COSMOS_TRANSFER1_7B_SAMPLE_AV_CHECKPOINT}/hdmap_control.pt" -LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH = f"{COSMOS_TRANSFER1_7B_SAMPLE_AV_CHECKPOINT}/lidar_control.pt" - -# Transfer1-7B-SV2MV-Sample-AV checkpoints - -COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT = "nvidia/Cosmos-Transfer1-7B-Sample-AV-Single2MultiView" -BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH = f"{COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT}/t2w_base_model.pt" -BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH = f"{COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT}/v2w_base_model.pt" -SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH = ( - f"{COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT}/t2w_hdmap_control.pt" -) -SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH = ( - f"{COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT}/t2w_lidar_control.pt" -) -SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH = ( - f"{COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT}/v2w_hdmap_control.pt" -) -SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH = ( - f"{COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT}/v2w_lidar_control.pt" -) - -SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH = ( - f"{COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT}/t2w_hdmap_control_waymo.pt" -) -SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH = ( - f"{COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT}/v2w_hdmap_control_waymo.pt" -) -EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH = f"{COSMOS_TRANSFER1_7B_CHECKPOINT}/edge_control_distilled.pt" diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/conditioner.py b/cosmos-transfer1/cosmos_transfer1/diffusion/conditioner.py deleted file mode 100644 index a9946ec7a80c8997e79a91e8ec960ba0e5d4505a..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/conditioner.py +++ /dev/null @@ -1,418 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
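# A hedged sketch of resolving one of the *_PATH constants above into a local file (the
# ./checkpoints target directory and the helper name are assumptions). Each constant is a
# "<HF repo id>/<file name>" string, so the two parts can be split off and fetched with
# huggingface_hub; note that the NVIDIA repos may require an accepted license and HF token.
from huggingface_hub import hf_hub_download

def fetch_checkpoint(checkpoint_path: str, local_dir: str = "./checkpoints") -> str:
    repo_id, filename = checkpoint_path.rsplit("/", 1)  # e.g. "nvidia/Cosmos-Transfer1-7B", "edge_control.pt"
    return hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)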
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -from abc import ABC, abstractmethod -from collections import defaultdict -from dataclasses import dataclass, fields -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn - -from cosmos_transfer1.diffusion.diffusion.functional.batch_ops import batch_mul -from cosmos_transfer1.utils import log -from cosmos_transfer1.utils.lazy_config import instantiate - - -class BaseConditionEntry(nn.Module): - def __init__(self): - super().__init__() - - self._dropout_rate = None - self._input_key = None - self._return_dict = False - - @property - def dropout_rate(self) -> Union[float, torch.Tensor]: - return self._dropout_rate - - @property - def input_key(self) -> str: - return self._input_key - - @property - def is_return_dict(self) -> bool: - return self._return_dict - - @dropout_rate.setter - def dropout_rate(self, value: Union[float, torch.Tensor]): - self._dropout_rate = value - - @input_key.setter - def input_key(self, value: str): - self._input_key = value - - @is_return_dict.setter - def is_return_dict(self, value: bool): - self._return_dict = value - - @dropout_rate.deleter - def dropout_rate(self): - del self._dropout_rate - - @input_key.deleter - def input_key(self): - del self._input_key - - @is_return_dict.deleter - def is_return_dict(self): - del self._return_dict - - def random_dropout_input( - self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None - ) -> torch.Tensor: - del key - dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate - return batch_mul( - torch.bernoulli((1.0 - dropout_rate) * torch.ones(in_tensor.shape[0])).type_as(in_tensor), - in_tensor, - ) - - def summary(self) -> str: - pass - - -class DataType(Enum): - IMAGE = "image" - VIDEO = "video" - - -class TextAttr(BaseConditionEntry): - def __init__(self): - super().__init__() - - def forward(self, token: torch.Tensor, mask: torch.Tensor): - return {"crossattn_emb": token, "crossattn_mask": mask} - - def random_dropout_input( - self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None - ) -> torch.Tensor: - if key is not None and "mask" in key: - return in_tensor - return super().random_dropout_input(in_tensor, dropout_rate, key) - - -class FrameRepeatAttr(BaseConditionEntry): - def __init__(self): - super().__init__() - - def forward(self, frame_repeat: torch.Tensor) -> Dict[str, torch.Tensor]: - return { - "frame_repeat": frame_repeat / 10.0, - } - - def details(self) -> str: - return "Frame repeat, Output key: [frame_repeat]" - - -@dataclass -class BaseVideoCondition: - crossattn_emb: torch.Tensor - crossattn_mask: torch.Tensor - data_type: DataType = DataType.VIDEO - padding_mask: Optional[torch.Tensor] = None - fps: Optional[torch.Tensor] = None - num_frames: Optional[torch.Tensor] = None - image_size: Optional[torch.Tensor] = None - scalar_feature: 
Optional[torch.Tensor] = None - frame_repeat: Optional[torch.Tensor] = None - regional_contexts: Optional[torch.Tensor] = None - region_masks: Optional[torch.Tensor] = None - - def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: - return {f.name: getattr(self, f.name) for f in fields(self)} - - -@dataclass -class VideoExtendCondition(BaseVideoCondition): - video_cond_bool: Optional[torch.Tensor] = None # whether or not it conditioned on video - gt_latent: Optional[torch.Tensor] = None - condition_video_indicator: Optional[torch.Tensor] = None # 1 for condition region - - # condition_video_input_mask will concat to the input of network, along channel dim; - # Will be concat with the input tensor - condition_video_input_mask: Optional[torch.Tensor] = None - # condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation, only valid when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed" - condition_video_augment_sigma: Optional[torch.Tensor] = None - # pose conditional input, will be concat with the input tensor - condition_video_pose: Optional[torch.Tensor] = None - - -@dataclass -class ViewConditionedVideoExtendCondition(VideoExtendCondition): - # view index indicating camera, used to index nn.Embedding - view_indices_B_T: Optional[torch.Tensor] = None - # number of cameras in this cond data - data_n_cameras: Optional[int] = -1 - - -class GeneralConditioner(nn.Module, ABC): - """ - An abstract module designed to handle various embedding models with conditional and - unconditional configurations. This abstract base class initializes and manages a collection - of embedders that can dynamically adjust their dropout rates based on conditioning. - - Attributes: - KEY2DIM (dict): A mapping from output keys to dimensions used for concatenation. - embedders (nn.ModuleDict): A dictionary containing all embedded models initialized and - configured based on the provided configurations. - - Parameters: - emb_models (Union[List, Any]): A dictionary where keys are embedder names and values - are configurations for initializing the embedders. - - """ - - KEY2DIM = {"crossattn_emb": 1, "crossattn_mask": 1} - - def __init__(self, **emb_models: Union[List, Any]): - super().__init__() - self.embedders = nn.ModuleDict() - for n, (emb_name, embconfig) in enumerate(emb_models.items()): - embedder = instantiate(embconfig.obj) - assert isinstance( - embedder, BaseConditionEntry - ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel" - embedder.dropout_rate = getattr(embconfig, "dropout_rate", 0.0) - - if hasattr(embconfig, "input_key"): - embedder.input_key = embconfig.input_key - elif hasattr(embconfig, "input_keys"): - embedder.input_keys = embconfig.input_keys - else: - raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}") - - log.debug(f"Initialized embedder #{n}-{emb_name}: \n {embedder.summary()}") - self.embedders[emb_name] = embedder - - @abstractmethod - def forward( - self, - batch: Dict, - override_dropout_rate: Optional[Dict[str, float]] = None, - ) -> Any: - """Should be implemented in subclasses to handle conditon datatype""" - raise NotImplementedError - - def _forward( - self, - batch: Dict, - override_dropout_rate: Optional[Dict[str, float]] = None, - ) -> Dict: - """ - Processes the input batch through all configured embedders, applying conditional dropout rates if specified. 
- Output tensors for each key are concatenated along the dimensions specified in KEY2DIM. - - Parameters: - batch (Dict): The input data batch to process. - override_dropout_rate (Optional[Dict[str, float]]): Optional dictionary to override default dropout rates - per embedder key. - - Returns: - Dict: A dictionary of output tensors concatenated by specified dimensions. - - Note: - In case the network code is sensitive to the order of concatenation, you can either control the order via \ - config file or make sure the embedders return a unique key for each output. - """ - output = defaultdict(list) - if override_dropout_rate is None: - override_dropout_rate = {} - - # make sure emb_name in override_dropout_rate is valid - for emb_name in override_dropout_rate.keys(): - assert emb_name in self.embedders, f"invalid name found {emb_name}" - - for emb_name, embedder in self.embedders.items(): - with torch.no_grad(): - if hasattr(embedder, "input_key") and (embedder.input_key is not None): - emb_out = embedder( - embedder.random_dropout_input( - batch[embedder.input_key], override_dropout_rate.get(emb_name, None) - ) - ) - elif hasattr(embedder, "input_keys"): - emb_out = embedder( - *[ - embedder.random_dropout_input(batch[k], override_dropout_rate.get(emb_name, None), k) - for k in embedder.input_keys - ] - ) - for k, v in emb_out.items(): - output[k].append(v) - # Concatenate the outputs - return {k: torch.cat(v, dim=self.KEY2DIM.get(k, -1)) for k, v in output.items()} - - def get_condition_uncondition( - self, - data_batch: Dict, - ) -> Tuple[Any, Any]: - """ - Processes the provided data batch to generate conditioned and unconditioned outputs. - - This method manipulates dropout rates to simulate two scenarios: - 1. All conditions applied (conditioned) - 2. Conditions removed/reduced to minimum (unconditioned) - - This method sets dropout rates to zero for the conditioned scenario to fully apply - embedders' effects. For unconditioned, it sets rates to 1 (or 0 if initial rate is - insignificant) to minimize embedder influences. - - Parameters: - data_batch (Dict): Input data batch containing all necessary information for - embedding processing. 
- - Returns: - Tuple[Any, Any]: A tuple containing: - - Outputs with all embedders fully applied (conditioned) - - Outputs with embedders minimized/not applied (unconditioned) - """ - cond_dropout_rates, dropout_rates = {}, {} - for emb_name, embedder in self.embedders.items(): - cond_dropout_rates[emb_name] = 0.0 - dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0 - - condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates) - un_condition: Any = self(data_batch, override_dropout_rate=dropout_rates) - return condition, un_condition - - def get_condition_with_negative_prompt( - self, - data_batch: Dict, - ) -> Tuple[Any, Any]: - """ - Similar functionality as get_condition_uncondition - But use negative prompts for unconditon - """ - cond_dropout_rates, uncond_dropout_rates = {}, {} - for emb_name, embedder in self.embedders.items(): - cond_dropout_rates[emb_name] = 0.0 - if isinstance(embedder, TextAttr): - uncond_dropout_rates[emb_name] = 0.0 - else: - uncond_dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0 - - data_batch_neg_prompt = copy.deepcopy(data_batch) - if "neg_t5_text_embeddings" in data_batch_neg_prompt: - if isinstance(data_batch_neg_prompt["neg_t5_text_embeddings"], torch.Tensor): - data_batch_neg_prompt["t5_text_embeddings"] = data_batch_neg_prompt["neg_t5_text_embeddings"] - data_batch_neg_prompt["t5_text_mask"] = data_batch_neg_prompt["neg_t5_text_mask"] - - condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates) - un_condition: Any = self(data_batch_neg_prompt, override_dropout_rate=uncond_dropout_rates) - - return condition, un_condition - - -@dataclass -class CosmosCondition: - crossattn_emb: torch.Tensor - crossattn_mask: torch.Tensor - padding_mask: Optional[torch.Tensor] = None - scalar_feature: Optional[torch.Tensor] = None - - def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: - return {f.name: getattr(self, f.name) for f in fields(self)} - - -class VideoConditioner(GeneralConditioner): - def forward( - self, - batch: Dict, - override_dropout_rate: Optional[Dict[str, float]] = None, - ) -> BaseVideoCondition: - output = super()._forward(batch, override_dropout_rate) - if "regional_contexts" in batch: - output["regional_contexts"] = batch["regional_contexts"] - if "region_masks" in batch: - output["region_masks"] = batch["region_masks"] - return BaseVideoCondition(**output) - - -class VideoExtendConditioner(GeneralConditioner): - def forward( - self, - batch: Dict, - override_dropout_rate: Optional[Dict[str, float]] = None, - ) -> VideoExtendCondition: - output = super()._forward(batch, override_dropout_rate) - return VideoExtendCondition(**output) - - -class ViewConditionedVideoExtendConditioner(GeneralConditioner): - def forward( - self, - batch: Dict, - override_dropout_rate: Optional[Dict[str, float]] = None, - ) -> ViewConditionedVideoExtendCondition: - output = super()._forward(batch, override_dropout_rate) - return ViewConditionedVideoExtendCondition(**output) - - -@dataclass -class BaseWithCtrlCondition(VideoExtendCondition): - control_input_edge: Optional[torch.Tensor] = None - control_input_vis: Optional[torch.Tensor] = None - control_input_depth: Optional[torch.Tensor] = None - control_input_seg: Optional[torch.Tensor] = None - control_input_keypoint: Optional[torch.Tensor] = None - control_input_upscale: Optional[torch.Tensor] = None - control_input_hdmap: Optional[torch.Tensor] = None - control_input_lidar: Optional[torch.Tensor] = None - control_input_multi: 
Optional[torch.Tensor] = None - base_model: Optional[torch.nn.Module] = None - hint_key: Optional[str] = None - control_weight: Optional[float] = 1.0 - num_layers_to_use: Optional[int] = -1 - - -@dataclass -class ViewConditionedWithCtrlCondition(BaseWithCtrlCondition): - # view index indicating camera, used to index nn.Embedding - view_indices_B_T: Optional[torch.Tensor] = None - # number of cameras in this cond data - data_n_views: Optional[int] = -1 - - -class VideoConditionerWithCtrl(VideoExtendConditioner): - def forward( - self, - batch: Dict, - override_dropout_rate: Optional[Dict[str, float]] = None, - ) -> BaseWithCtrlCondition: - output = super()._forward(batch, override_dropout_rate) - output["hint_key"] = batch["hint_key"] - if "control_weight" in batch: - output["control_weight"] = batch["control_weight"] - if "num_layers_to_use" in batch: - output["num_layers_to_use"] = batch["num_layers_to_use"] - return BaseWithCtrlCondition(**output) - - -class ViewConditionedVideoConditionerWithCtrl(VideoConditionerWithCtrl): - def forward( - self, - batch: Dict, - override_dropout_rate: Optional[Dict[str, float]] = None, - ) -> ViewConditionedWithCtrlCondition: - output = super()._forward(batch, override_dropout_rate) - output["hint_key"] = batch["hint_key"] - if "control_weight" in batch: - output["control_weight"] = batch["control_weight"] - if "num_layers_to_use" in batch: - output["num_layers_to_use"] = batch["num_layers_to_use"] - return ViewConditionedWithCtrlCondition(**output) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/__init__.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/__init__.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/conditioner.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/conditioner.py deleted file mode 100644 index d617636ded626dac1af70b06621d65e141cc88de..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/conditioner.py +++ /dev/null @@ -1,209 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
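# A brief usage sketch of the GeneralConditioner API defined above (illustrative;
# `conditioner` and `data_batch` are assumed to already exist). For classifier-free guidance
# the same batch is embedded twice: once with every embedder fully applied and once with the
# droppable embedders zeroed out.
condition, uncondition = conditioner.get_condition_uncondition(data_batch)

# With negative prompts, the unconditional branch keeps the text embedder active but
# substitutes the batch's neg_t5_text_embeddings / neg_t5_text_mask entries.
condition, uncondition = conditioner.get_condition_with_negative_prompt(data_batch)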
- -from typing import Dict, List, Optional - -import attrs -import torch - -from cosmos_transfer1.diffusion.conditioner import ( - BaseConditionEntry, - TextAttr, - VideoConditioner, - VideoExtendConditioner, -) -from cosmos_transfer1.utils.lazy_config import LazyCall as L -from cosmos_transfer1.utils.lazy_config import LazyDict - - -@attrs.define(slots=False) -class TextConfig: - obj: LazyDict = L(TextAttr)() # No arguments - dropout_rate: float = 0.2 - input_keys: List[str] = attrs.field(factory=lambda: ["t5_text_embeddings", "t5_text_mask"]) - - -class BooleanFlag(BaseConditionEntry): - def __init__(self, output_key: Optional[str] = None): - super().__init__() - self.output_key = output_key - - def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: - del args, kwargs - key = self.output_key if self.output_key else self.input_key - return {key: self.flag} - - def random_dropout_input( - self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None - ) -> torch.Tensor: - del key - dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate - self.flag = torch.bernoulli((1.0 - dropout_rate) * torch.ones(1)).bool().to(device=in_tensor.device) - return in_tensor - - -class ReMapkey(BaseConditionEntry): - def __init__(self, output_key: Optional[str] = None, dtype: Optional[str] = None): - super().__init__() - self.output_key = output_key - self.dtype = { - None: None, - "float": torch.float32, - "bfloat16": torch.bfloat16, - "half": torch.float16, - "float16": torch.float16, - "int": torch.int32, - "long": torch.int64, - }[dtype] - - def forward(self, element: torch.Tensor) -> Dict[str, torch.Tensor]: - key = self.output_key if self.output_key else self.input_key - if isinstance(element, torch.Tensor): - element = element.to(dtype=self.dtype) - return {key: element} - - def details(self) -> str: - key = self.output_key if self.output_key else self.input_key - return f"Output key: {key} \n\tDtype: {self.dtype}" - - -@attrs.define(slots=False) -class FPSConfig: - """ - Remap the key from the input dictionary to the output dictionary. For `fps`. - """ - - obj: LazyDict = L(ReMapkey)(output_key="fps", dtype=None) - dropout_rate: float = 0.0 - input_key: str = "fps" - - -@attrs.define(slots=False) -class PaddingMaskConfig: - """ - Remap the key from the input dictionary to the output dictionary. For `padding_mask`. - """ - - obj: LazyDict = L(ReMapkey)(output_key="padding_mask", dtype=None) - dropout_rate: float = 0.0 - input_key: str = "padding_mask" - - -@attrs.define(slots=False) -class ImageSizeConfig: - """ - Remap the key from the input dictionary to the output dictionary. For `image_size`. - """ - - obj: LazyDict = L(ReMapkey)(output_key="image_size", dtype=None) - dropout_rate: float = 0.0 - input_key: str = "image_size" - - -@attrs.define(slots=False) -class NumFramesConfig: - """ - Remap the key from the input dictionary to the output dictionary. For `num_frames`. - """ - - obj: LazyDict = L(ReMapkey)(output_key="num_frames", dtype=None) - dropout_rate: float = 0.0 - input_key: str = "num_frames" - - -@attrs.define(slots=False) -class VideoCondBoolConfig: - obj: LazyDict = L(BooleanFlag)(output_key="video_cond_bool") - dropout_rate: float = 0.2 - input_key: str = "fps" # This is a placeholder, we never use this value - - # Config below are for long video generation only - compute_loss_for_condition_region: bool = False # Compute loss for condition region - - # How to sample condition region during training. 
"first_random_n" set the first n frames to be condition region, n is random, "random" set the condition region to be random, - condition_location: str = "first_random_n" - random_conditon_rate: float = 0.5 # The rate to sample the condition region randomly - first_random_n_num_condition_t_max: int = 4 # The maximum number of frames to sample as condition region, used when condition_location is "first_random_n" - first_random_n_num_condition_t_min: int = 0 # The minimum number of frames to sample as condition region, used when condition_location is "first_random_n" - - # How to dropout value of the conditional input frames - cfg_unconditional_type: str = "zero_condition_region_condition_mask" # Unconditional type. "zero_condition_region_condition_mask" set the input to zero for condition region, "noise_x_condition_region" set the input to x_t, same as the base model - - # How to corrupt the condition region - apply_corruption_to_condition_region: str = "noise_with_sigma" # Apply corruption to condition region, option: "gaussian_blur", "noise_with_sigma", "clean" (inference), "noise_with_sigma_fixed" (inference) - # Inference only option: list of sigma value for the corruption at different chunk id, used when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed" - apply_corruption_to_condition_region_sigma_value: list[float] = [0.001, 0.2] + [ - 0.5 - ] * 10 # Sigma value for the corruption, used when apply_corruption_to_condition_region is "noise_with_sigma_fixed" - - # Add augment_sigma condition to the network - condition_on_augment_sigma: bool = False - # The following arguments is to match with previous implementation where we use train sde to sample augment sigma (with adjust video noise turn on) - augment_sigma_sample_p_mean: float = 0.0 # Mean of the augment sigma - augment_sigma_sample_p_std: float = 1.0 # Std of the augment sigma - augment_sigma_sample_multiplier: float = 4.0 # Multipler of augment sigma - - # Add pose condition to the network - add_pose_condition: bool = False - - # Sample PPP... from IPPP... sequence - sample_tokens_start_from_p_or_i: bool = False - - # Normalize the input condition latent - normalize_condition_latent: bool = False - - -@attrs.define(slots=False) -class LatentConditionConfig: - """ - Remap the key from the input dictionary to the output dictionary. For `latent condition`. - """ - - obj: LazyDict = L(ReMapkey)(output_key="latent_condition", dtype=None) - dropout_rate: float = 0.0 - input_key: str = "latent_condition" - - -@attrs.define(slots=False) -class LatentConditionSigmaConfig: - """ - Remap the key from the input dictionary to the output dictionary. For `latent condition`. 
- """ - - obj: LazyDict = L(ReMapkey)(output_key="latent_condition_sigma", dtype=None) - dropout_rate: float = 0.0 - input_key: str = "latent_condition_sigma" - - -BaseVideoConditionerConfig: LazyDict = L(VideoConditioner)( - text=TextConfig(), -) - -VideoConditionerFpsSizePaddingConfig: LazyDict = L(VideoConditioner)( - text=TextConfig(), - fps=FPSConfig(), - num_frames=NumFramesConfig(), - image_size=ImageSizeConfig(), - padding_mask=PaddingMaskConfig(), -) - -VideoExtendConditionerConfig: LazyDict = L(VideoExtendConditioner)( - text=TextConfig(), - fps=FPSConfig(), - num_frames=NumFramesConfig(), - image_size=ImageSizeConfig(), - padding_mask=PaddingMaskConfig(), - video_cond_bool=VideoCondBoolConfig(), -) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/data.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/data.py deleted file mode 100644 index 3d320fc92076cafbaf159921bafe5fcf11c61198..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/data.py +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from megatron.core import parallel_state -from torch.utils.data import DataLoader, DistributedSampler - -from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS -from cosmos_transfer1.diffusion.datasets.example_transfer_dataset import ExampleTransferDataset -from cosmos_transfer1.utils.lazy_config import LazyCall as L - - -def get_sampler(dataset): - return DistributedSampler( - dataset, - num_replicas=parallel_state.get_data_parallel_world_size(), - rank=parallel_state.get_data_parallel_rank(), - shuffle=True, - seed=0, - ) - - -def get_example_transfer_dataset(hint_key, is_train=True): - dataset = L(ExampleTransferDataset)( - dataset_dir="datasets/hdvila", - num_frames=121, - resolution="720", - hint_key=hint_key, - is_train=is_train, - ) - - return L(DataLoader)( - dataset=dataset, - sampler=L(get_sampler)(dataset=dataset), - batch_size=1, - drop_last=True, - num_workers=8, # adjust as needed - prefetch_factor=2, # adjust as needed - pin_memory=True, - ) - - -# NOTE 1: For customized post train: add your dataloader registration here. -# NOTE 2: The loop below simply registers a dataset for all hint_keys in CTRL_HINT_KEYS. The actual data might not exist. 
-def register_data_ctrlnet(cs): - for hint_key in CTRL_HINT_KEYS: - cs.store( - group="data_train", - package="dataloader_train", - name=f"example_transfer_train_data_{hint_key}", - node=get_example_transfer_dataset(hint_key=hint_key, is_train=True), - ) - cs.store( - group="data_val", - package="dataloader_val", - name=f"example_transfer_val_data_{hint_key}", - node=get_example_transfer_dataset(hint_key=hint_key, is_train=False), - ) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/model.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/model.py deleted file mode 100644 index 338a9b6d73126f59f78d85efd22226a1ce38aa37..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/model.py +++ /dev/null @@ -1,98 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List - -import attrs - -from cosmos_transfer1.diffusion.config.training.ema import PowerEMAConfig -from cosmos_transfer1.diffusion.training.modules.edm_sde import EDMSDE -from cosmos_transfer1.utils.lazy_config import LazyCall as L -from cosmos_transfer1.utils.lazy_config import LazyDict - - -@attrs.define(slots=False) -class FSDPConfig: - policy: str = "block" - checkpoint: bool = False - min_num_params: int = 1024 - sharding_group_size: int = 8 - sharding_strategy: str = "full" - - -@attrs.define(slots=False) -class DefaultModelConfig: - tokenizer: LazyDict = None - conditioner: LazyDict = None - net: LazyDict = None - sigma_data: float = 0.5 - precision: str = "bfloat16" - input_data_key: str = "video" # key to fetch input data from data_batch - latent_shape: List[int] = [16, 24, 44, 80] # 24 corresponig to 136 frames - - # training related - ema: LazyDict = PowerEMAConfig - sde: LazyDict = L(EDMSDE)( - p_mean=0.0, - p_std=1.0, - sigma_max=80, - sigma_min=0.0002, - ) - camera_sample_weight: LazyDict = LazyDict( - dict( - enabled=False, - weight=5.0, - ) - ) - aesthetic_finetuning: LazyDict = LazyDict( - dict( - enabled=False, - ) - ) - loss_mask_enabled: bool = False - loss_masking: LazyDict = None - loss_add_logvar: bool = True - input_image_key: str = "images_1024" # key to fetch input image from data_batch - loss_reduce: str = "sum" - loss_scale: float = 1.0 - fsdp_enabled: bool = False - use_torch_compile: bool = False - fsdp: FSDPConfig = attrs.field(factory=FSDPConfig) - use_dummy_temporal_dim: bool = False # Whether to use dummy temporal dimension in data - adjust_video_noise: bool = False # whether or not adjust video noise accroding to the video length - context_parallel_size: int = 1 # Number of context parallel groups - - # `num_latents_to_drop` is mechanism to satisfy the CP%8==0 and (1I,N*P,1I) latents setup. - # Since our tokenizer is causal and has the `T+1` input frames setup, it makes it - # a little challenging to sample exact number of frames from file, and encode those. 
- # Instead, we sample as many frame from file, run the tokenizer twice, and discard the second - # chunk's P-latents, ensuring the above two requirements. By default, this flag does not have any effect. - num_latents_to_drop: int = 0 # number of latents to drop - - -@attrs.define(slots=False) -class MultiviewModelConfig(DefaultModelConfig): - n_views: int = 6 - - -@attrs.define(slots=False) -class LatentDiffusionDecoderModelConfig(DefaultModelConfig): - tokenizer_corruptor: LazyDict = None - latent_corruptor: LazyDict = None - pixel_corruptor: LazyDict = None - diffusion_decoder_cond_sigma_low: float = None - diffusion_decoder_cond_sigma_high: float = None - diffusion_decoder_corrupt_prob: float = None - condition_on_tokenizer_corruptor_token: bool = False diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/net.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/net.py deleted file mode 100644 index aef9a9f771ac8425ca03ef78ae3a4434171442e5..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/net.py +++ /dev/null @@ -1,40 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from cosmos_transfer1.diffusion.networks.general_dit import GeneralDIT -from cosmos_transfer1.utils.lazy_config import LazyCall as L -from cosmos_transfer1.utils.lazy_config import LazyDict - -FADITV2Config: LazyDict = L(GeneralDIT)( - max_img_h=240, - max_img_w=240, - max_frames=128, - in_channels=16, - out_channels=16, - patch_spatial=2, - patch_temporal=1, - model_channels=4096, - block_config="FA-CA-MLP", - num_blocks=28, - num_heads=32, - concat_padding_mask=True, - pos_emb_cls="rope3d", - pos_emb_learnable=False, - pos_emb_interpolation="crop", - block_x_format="THWBD", - affline_emb_norm=True, - use_adaln_lora=True, - adaln_lora_dim=256, -) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/tokenizer.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/tokenizer.py deleted file mode 100644 index 73d5c97c6c2ed7f297536c47c92710714753e680..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/base/tokenizer.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
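# A hedged sketch (the override value is hypothetical) of turning a LazyCall template such as
# FADITV2Config above into a live module, mirroring how the conditioner code earlier in this
# diff instantiates LazyDict nodes.
import copy

from cosmos_transfer1.utils.lazy_config import instantiate

net_config = copy.deepcopy(FADITV2Config)  # keep the shared template untouched
net_config.num_blocks = 28                 # override a field before instantiation
net = instantiate(net_config)              # builds the GeneralDIT described by the config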
- -import omegaconf - -from cosmos_transfer1.diffusion.module.pretrained_vae import ( - JITVAE, - JointImageVideoSharedJITTokenizer, - VideoJITTokenizer, -) -from cosmos_transfer1.utils.lazy_config import LazyCall as L - -TOKENIZER_OPTIONS = {} - - -def tokenizer_register(key): - def decorator(func): - TOKENIZER_OPTIONS[key] = func - return func - - return decorator - - -@tokenizer_register("cosmos_diffusion_tokenizer_comp8x8x8") -def get_cosmos_diffusion_tokenizer_comp8x8x8(resolution: str, chunk_duration: int) -> omegaconf.dictconfig.DictConfig: - assert resolution in ["720"] - - pixel_chunk_duration = chunk_duration - temporal_compression_factor = 8 - spatial_compression_factor = 8 - - return L(JointImageVideoSharedJITTokenizer)( - video_vae=L(VideoJITTokenizer)( - name="cosmos_1_0_diffusion_tokenizer", - latent_ch=16, - is_bf16=True, - pixel_chunk_duration=pixel_chunk_duration, - temporal_compression_factor=temporal_compression_factor, - spatial_compression_factor=spatial_compression_factor, - spatial_resolution=resolution, - ), - image_vae=L(JITVAE)( - name="cosmos_1_0_diffusion_tokenizer", - latent_ch=16, - is_image=False, - is_bf16=True, - ), - name="cosmos_1_0_diffusion_tokenizer", - latent_ch=16, - ) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/config.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/config.py deleted file mode 100644 index d4281327eabd569fa3dfb2ba01216a115889660f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/config.py +++ /dev/null @@ -1,53 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
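# An illustrative sketch of extending TOKENIZER_OPTIONS through the tokenizer_register
# decorator above (the "my_custom..." key and the reuse of the stock builder are
# hypothetical). A registered entry is simply a callable that returns a LazyCall config.
@tokenizer_register("my_custom_tokenizer_comp8x8x8")
def get_my_custom_tokenizer(resolution: str, chunk_duration: int):
    # Reuse the stock builder; a real entry would typically change the VAE settings.
    return get_cosmos_diffusion_tokenizer_comp8x8x8(resolution=resolution, chunk_duration=chunk_duration)

# A config can then look the builder up by key:
tokenizer_cfg = TOKENIZER_OPTIONS["my_custom_tokenizer_comp8x8x8"]("720", 121)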
- -from typing import Any, List - -import attrs - -from cosmos_transfer1.diffusion.config.base.model import DefaultModelConfig -from cosmos_transfer1.diffusion.config.registry import register_configs -from cosmos_transfer1.utils import config -from cosmos_transfer1.utils.config_helper import import_all_modules_from_package - - -@attrs.define(slots=False) -class Config(config.Config): - # default config groups that will be used unless overwritten - # see config groups in registry.py - defaults: List[Any] = attrs.field( - factory=lambda: [ - "_self_", - {"net": None}, - {"conditioner": "add_fps_image_size_padding_mask"}, - {"tokenizer": "tokenizer"}, - {"experiment": None}, - ] - ) - - -def make_config(): - c = Config( - model=DefaultModelConfig(), - ) - c.job.project = "cosmos_diffusion" - c.job.group = "inference" - - register_configs() - - # experiment config are defined in the experiment folder - # call import_all_modules_from_package to register them - import_all_modules_from_package("cosmos_transfer1.diffusion.config.inference", reload=True) - return c diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/config_train.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/config_train.py deleted file mode 100644 index 7b20b5a27c9484deef5d30553aefca2db70283ca..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/config_train.py +++ /dev/null @@ -1,88 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
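# A hedged sketch of consuming the inference make_config() above (the overrides shown are
# hypothetical; in practice the Hydra-style defaults list selects registered config groups
# and an experiment entry by name).
cfg = make_config()
cfg.job.group = "inference"                  # matches the default set in make_config()
cfg.model.latent_shape = [16, 16, 88, 160]   # hypothetical override of a DefaultModelConfig field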
- -from typing import Any, List - -import attrs - -from cosmos_transfer1.checkpointer.ema_fsdp_checkpointer import CheckpointConfig -from cosmos_transfer1.diffusion.config.training.registry_extra import register_configs -from cosmos_transfer1.diffusion.config.transfer.model import CtrlModelConfig -from cosmos_transfer1.diffusion.training.models.model_ctrl import VideoDiffusionModelWithCtrl -from cosmos_transfer1.utils import config -from cosmos_transfer1.utils.config_helper import import_all_modules_from_package -from cosmos_transfer1.utils.lazy_config import PLACEHOLDER -from cosmos_transfer1.utils.lazy_config import LazyCall as L -from cosmos_transfer1.utils.lazy_config import LazyDict -from cosmos_transfer1.utils.trainer import Trainer - - -@attrs.define(slots=False) -class Config(config.Config): - # default config groups that will be used unless overwritten - # see config groups in registry.py - defaults: List[Any] = attrs.field( - factory=lambda: [ - "_self_", - {"data_train": None}, - {"data_val": None}, - {"optimizer": "fusedadamw"}, - {"scheduler": "lambdalinear"}, - {"callbacks": None}, - # - {"net": None}, - {"net_ctrl": None}, - {"hint_key": "control_input_edge"}, - {"conditioner": "ctrlnet_add_fps_image_size_padding_mask"}, - {"pixel_corruptor": None}, - {"fsdp": None}, - {"ema": "power"}, - {"checkpoint": "local"}, - {"ckpt_klass": "multi_rank"}, - {"tokenizer": "vae1"}, - # the list is with order, we need global experiment to be the last one - {"experiment": None}, - ] - ) - model_obj: LazyDict = L(VideoDiffusionModelWithCtrl)( - config=PLACEHOLDER, - ) - checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig) - - -def make_config(): - c = Config( - model=CtrlModelConfig(), - optimizer=None, - scheduler=None, - dataloader_train=None, - dataloader_val=None, - ) - - c.job.project = "cosmos_transfer1" - c.job.group = "debug" - c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}" - - c.trainer.type = Trainer - # c.trainer.straggler_detection.enabled = False - c.trainer.max_iter = 400_000 - c.trainer.logging_iter = 10 - c.trainer.validation_iter = 100 - c.trainer.run_validation = False - c.trainer.callbacks = None - - register_configs() - import_all_modules_from_package("cosmos_transfer1.diffusion.config.training.experiment", reload=True) - return c diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/inference/__init__.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/inference/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/inference/cosmos-1-diffusion-control2world.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/inference/cosmos-1-diffusion-control2world.py deleted file mode 100644 index 991e07f816a1ee09e242033ae5ed2d1ae7ef1faa..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/inference/cosmos-1-diffusion-control2world.py +++ /dev/null @@ -1,401 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
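# An illustrative sketch (all values hypothetical) of adjusting the training config returned
# by make_config() above for a short ControlNet post-training run; the defaults list already
# selects the multi_rank checkpointer and the ctrlnet conditioner, so only trainer and job
# fields are touched here.
cfg = make_config()
cfg.trainer.max_iter = 10_000          # much shorter than the 400_000-iteration default
cfg.trainer.logging_iter = 50
cfg.trainer.run_validation = True
cfg.trainer.validation_iter = 500
cfg.job.name = "edge_posttrain_debug"  # hypothetical job name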
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from hydra.core.config_store import ConfigStore - -from cosmos_transfer1.checkpoints import ( - BASE_7B_CHECKPOINT_AV_SAMPLE_PATH, - EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH, - BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, - BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, -) -from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS_COMB -from cosmos_transfer1.diffusion.model.model_ctrl import ( - VideoDiffusionModelWithCtrl, - VideoDiffusionT2VModelWithCtrl, - VideoDistillModelWithCtrl, -) -from cosmos_transfer1.diffusion.model.model_multi_camera_ctrl import MultiVideoDiffusionModelWithCtrl -from cosmos_transfer1.diffusion.networks.general_dit_multi_view import MultiViewVideoExtendGeneralDIT -from cosmos_transfer1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT -from cosmos_transfer1.utils.lazy_config import LazyCall as L -from cosmos_transfer1.utils.lazy_config import LazyDict - -cs = ConfigStore.instance() - -# Base configuration for 7B model -Base_7B_Config = LazyDict( - dict( - defaults=[ - {"override /net": "faditv2_7b"}, - {"override /conditioner": "add_fps_image_size_padding_mask"}, - {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, - "_self_", - ], - model=dict( - latent_shape=[16, 16, 88, 160], - net=dict( - rope_h_extrapolation_ratio=1, - rope_w_extrapolation_ratio=1, - rope_t_extrapolation_ratio=2, - ), - ), - job=dict( - group="Control2World", - name="Base_7B_Config", - ), - ) -) - - -def make_ctrlnet_config_7b( - hint_key: str = "control_input_seg", - num_control_blocks: int = 3, -) -> LazyDict: - hint_mask = [True] * len(CTRL_HINT_KEYS_COMB[hint_key]) - - return LazyDict( - dict( - defaults=[ - "/experiment/Base_7B_Config", - {"override /hint_key": hint_key}, - {"override /net_ctrl": "faditv2_7b"}, - {"override /conditioner": "ctrlnet_add_fps_image_size_padding_mask"}, - ], - job=dict( - group="CTRL_7Bv1_lvg", - name=f"CTRL_7Bv1pt3_lvg_tp_121frames_{hint_key}_block{num_control_blocks}", - project="cosmos_transfer1", - ), - model=dict( - hint_mask=hint_mask, - hint_dropout_rate=0.3, - conditioner=dict(video_cond_bool=dict()), - net=L(VideoExtendGeneralDIT)( - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - extra_per_block_abs_pos_emb_type="learnable", - ), - net_ctrl=dict( - in_channels=17, - hint_channels=128, - num_blocks=28, - layer_mask=[True if (i >= num_control_blocks) else False for i in range(28)], - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - extra_per_block_abs_pos_emb_type="learnable", - ), - ), - model_obj=L(VideoDiffusionModelWithCtrl)(), - ) - ) - - -def make_ctrlnet_config_7b_t2v( - hint_key: str = "control_input_seg", - num_control_blocks: int = 3, -) -> LazyDict: - hint_mask = [True] * len(CTRL_HINT_KEYS_COMB[hint_key]) - - return LazyDict( - dict( - defaults=[ - "/experiment/Base_7B_Config", - {"override /hint_key": hint_key}, - {"override /net_ctrl": "faditv2_7b"}, - {"override /conditioner": "ctrlnet_add_fps_image_size_padding_mask"}, - ], - job=dict( - group="CTRL_7Bv1_t2v", - 
name=f"CTRL_7Bv1pt3_t2v_121frames_{hint_key}_block{num_control_blocks}", - project="cosmos_ctrlnet1", - ), - model=dict( - base_load_from=dict( - load_path=f"checkpoints/{BASE_7B_CHECKPOINT_AV_SAMPLE_PATH}", - ), - hint_mask=hint_mask, - hint_dropout_rate=0.3, - net=dict( - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - extra_per_block_abs_pos_emb_type="learnable", - ), - net_ctrl=dict( - in_channels=16, - hint_channels=16, - num_blocks=28, - layer_mask=[True if (i >= num_control_blocks) else False for i in range(28)], - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - extra_per_block_abs_pos_emb_type="learnable", - ), - ), - model_obj=L(VideoDiffusionT2VModelWithCtrl)(), - ) - ) - - -def make_ctrlnet_config_7b_mv( - hint_key: str = "control_input_seg", - num_control_blocks: int = 3, - t2w: bool = True, -) -> LazyDict: - hint_mask = [True] * len(CTRL_HINT_KEYS_COMB[hint_key]) - - return LazyDict( - dict( - defaults=[ - "/experiment/Base_7B_Config", - {"override /hint_key": hint_key}, - {"override /net_ctrl": "faditv2_7b_mv"}, - {"override /conditioner": "view_cond_ctrlnet_add_fps_image_size_padding_mask"}, - ], - job=dict( - group="CTRL_7Bv1_mv", - name=f"CTRL_7Bv1pt3_sv2mv_{'t2w' if t2w else 'v2w'}_57frames_{hint_key}_block{num_control_blocks}", - project="cosmos_ctrlnet1", - ), - model=dict( - n_views=6, - base_load_from=dict( - load_path=f"checkpoints/{BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH}" - if t2w - else f"checkpoints/{BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH}", - ), - hint_mask=hint_mask, - hint_dropout_rate=0.3, - conditioner=dict( - video_cond_bool=dict( - condition_location="first_cam" if t2w else "first_cam_and_first_n", - ) - ), - net=L(MultiViewVideoExtendGeneralDIT)( - n_views=6, - n_views_emb=7, - camera_condition_dim=6, - add_repeat_frame_embedding=True, - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - ), - net_ctrl=dict( - in_channels=16, - hint_channels=16, - num_blocks=28, - n_views=6, - n_views_emb=7, - camera_condition_dim=6, - add_repeat_frame_embedding=True, - is_extend_model=True, - layer_mask=[True if (i >= num_control_blocks) else False for i in range(28)], - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - ), - tokenizer=dict( - video_vae=dict( - pixel_chunk_duration=57, - ) - ), - ), - model_obj=L(MultiVideoDiffusionModelWithCtrl)(), - ) - ) - - -def make_ctrlnet_config_7b_mv_waymo( - hint_key: str = "control_input_seg", - num_control_blocks: int = 3, - t2w: bool = True, -) -> LazyDict: - hint_mask = [True] * len(CTRL_HINT_KEYS_COMB[hint_key]) - - return LazyDict( - dict( - defaults=[ - "/experiment/Base_7B_Config", - {"override /hint_key": hint_key}, - {"override /net_ctrl": "faditv2_7b_mv"}, - {"override /conditioner": "view_cond_ctrlnet_add_fps_image_size_padding_mask"}, - ], - job=dict( - group="CTRL_7Bv1_mv", - name=f"CTRL_7Bv1pt3_sv2mv_{'t2w' if t2w else 'v2w'}_57frames_{hint_key}_waymo_block{num_control_blocks}", - project="cosmos_ctrlnet1", - ), - model=dict( - n_views=5, - base_load_from=dict( - load_path=f"checkpoints/{BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH}" - if t2w - else f"checkpoints/{BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH}", - ), - hint_mask=hint_mask, - hint_dropout_rate=0.15, - conditioner=dict( - video_cond_bool=dict( - condition_location="first_cam" if t2w else "first_cam_and_first_n", - cfg_unconditional_type="zero_condition_region_condition_mask", - apply_corruption_to_condition_region="noise_with_sigma", - condition_on_augment_sigma=False, - dropout_rate=0.0, - 
first_random_n_num_condition_t_max=0 if t2w else 2, - normalize_condition_latent=False, - augment_sigma_sample_p_mean=-3.0, - augment_sigma_sample_p_std=2.0, - augment_sigma_sample_multiplier=1.0, - ) - ), - net=L(MultiViewVideoExtendGeneralDIT)( - n_views=5, - n_views_emb=7, - camera_condition_dim=6, - add_repeat_frame_embedding=True, - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - extra_per_block_abs_pos_emb_type="learnable", - num_blocks=28, - ), - adjust_video_noise=True, - net_ctrl=dict( - in_channels=16, - hint_channels=16, - num_blocks=28, - n_views=5, - n_views_emb=7, - camera_condition_dim=6, - add_repeat_frame_embedding=True, - is_extend_model=True, - layer_mask=[True if (i >= num_control_blocks) else False for i in range(28)], - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - extra_per_block_abs_pos_emb_type="learnable", - ), - tokenizer=dict( - video_vae=dict( - pixel_chunk_duration=57, - ) - ), - ), - model_obj=L(MultiVideoDiffusionModelWithCtrl)(), - ) - ) - - -# Register base configs -cs.store(group="experiment", package="_global_", name=Base_7B_Config["job"]["name"], node=Base_7B_Config) -# Register all control configurations -num_control_blocks = 3 -for key in CTRL_HINT_KEYS_COMB.keys(): - # Register 7B configurations - config_7b = make_ctrlnet_config_7b(hint_key=key, num_control_blocks=num_control_blocks) - cs.store(group="experiment", package="_global_", name=config_7b["job"]["name"], node=config_7b) - -# Register t2v based control net -num_control_blocks = 3 -for key in ["control_input_hdmap", "control_input_lidar"]: - # Register 7B configurations - config_7b = make_ctrlnet_config_7b_t2v(hint_key=key, num_control_blocks=num_control_blocks) - cs.store(group="experiment", package="_global_", name=config_7b["job"]["name"], node=config_7b) - -num_control_blocks = 3 -for key in ["control_input_hdmap", "control_input_lidar"]: - for t2w in [True, False]: - # Register 7B sv2mv configurations - config_7b = make_ctrlnet_config_7b_mv(hint_key=key, num_control_blocks=num_control_blocks, t2w=t2w) - cs.store(group="experiment", package="_global_", name=config_7b["job"]["name"], node=config_7b) - - -# Register waymo example -num_control_blocks = 3 -for key in ["control_input_hdmap", "control_input_lidar"]: - for t2w in [True, False]: - # Register 7B sv2mv configurations - config_7b = make_ctrlnet_config_7b_mv_waymo(hint_key=key, num_control_blocks=num_control_blocks, t2w=t2w) - cs.store(group="experiment", package="_global_", name=config_7b["job"]["name"], node=config_7b) - - -def make_ctrlnet_config_7b_distilled( - hint_key: str = "control_input_edge", - num_control_blocks: int = 3, -) -> LazyDict: - hint_mask = [True] * len(CTRL_HINT_KEYS_COMB[hint_key]) - - return LazyDict( - dict( - defaults=[ - "/experiment/Base_7B_Config", - {"override /hint_key": hint_key}, - {"override /net": "faditv2_7b"}, - {"override /net_ctrl": "faditv2_7b"}, - {"override /conditioner": "ctrlnet_add_fps_image_size_padding_mask"}, - "_self_", - ], - job=dict( - group="DISTILL_CTRL_7Bv1", - name=f"CTRL_7Bv1pt3_lvg_fsdp_distilled_121frames_{hint_key}_block{num_control_blocks}", - project="cosmos_nano_v1", - ), - model=dict( - base_load_from=dict( - load_path=f"checkpoints/{EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH}", - ), - hint_mask=hint_mask, - hint_dropout_rate=0.0, - conditioner=dict( - video_cond_bool=dict( - condition_location="first_random_n", - cfg_unconditional_type="zero_condition_region_condition_mask", - 
apply_corruption_to_condition_region="noise_with_sigma_fixed", - condition_on_augment_sigma=False, - dropout_rate=0.0, - first_random_n_num_condition_t_max=2, - first_random_n_num_condition_t_min=0, - normalize_condition_latent=False, - augment_sigma_sample_p_mean=-3.0, - augment_sigma_sample_p_std=2.0, - augment_sigma_sample_multiplier=1.0, - ) - ), - net=L(VideoExtendGeneralDIT)( - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - extra_per_block_abs_pos_emb_type="learnable", - rope_t_extrapolation_ratio=2, - ), - net_ctrl=dict( - in_channels=17, - hint_channels=128, - num_blocks=28, - layer_mask=[True if (i >= num_control_blocks) else False for i in range(28)], - num_control_blocks=num_control_blocks, - dropout_ctrl_branch=0, - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - extra_per_block_abs_pos_emb_type="learnable", - ), - ), - model_obj=L(VideoDistillModelWithCtrl)(), - ) - ) - - -# Register the specific distilled configuration -distilled_config = make_ctrlnet_config_7b_distilled(hint_key="control_input_edge", num_control_blocks=3) -cs.store( - group="experiment", - package="_global_", - name="dev_v2w_ctrl_7bv1pt3_VisControlCanny_video_only_dmd2_fsdp", - node=distilled_config, -) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/registry.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/registry.py deleted file mode 100644 index 5ebd589e00151d18ab1f8d97ca5216050597a28e..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/registry.py +++ /dev/null @@ -1,74 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from hydra.core.config_store import ConfigStore - -from cosmos_transfer1.diffusion.config.base.conditioner import ( - BaseVideoConditionerConfig, - VideoConditionerFpsSizePaddingConfig, - VideoExtendConditionerConfig, -) -from cosmos_transfer1.diffusion.config.base.net import FADITV2Config -from cosmos_transfer1.diffusion.config.base.tokenizer import get_cosmos_diffusion_tokenizer_comp8x8x8 - - -def register_net(cs): - cs.store( - group="net", - package="model.net", - name="faditv2_7b", - node=FADITV2Config, - ) - - -def register_conditioner(cs): - cs.store( - group="conditioner", - package="model.conditioner", - name="basic", - node=BaseVideoConditionerConfig, - ) - cs.store( - group="conditioner", - package="model.conditioner", - name="add_fps_image_size_padding_mask", - node=VideoConditionerFpsSizePaddingConfig, - ) - cs.store( - group="conditioner", - package="model.conditioner", - name="video_cond", - node=VideoExtendConditionerConfig, - ) - - -def register_tokenizer(cs): - cs.store( - group="tokenizer", - package="model.tokenizer", - name="cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624", - node=get_cosmos_diffusion_tokenizer_comp8x8x8(resolution="720", chunk_duration=121), - ) - - -def register_configs(): - """ - base model related registry - """ - cs = ConfigStore.instance() - - register_net(cs) - register_conditioner(cs) - register_tokenizer(cs) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/__init__.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/callbacks.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/callbacks.py deleted file mode 100644 index 85a0d68c1e1b9a9c911dab5a40dd0ac352b6818f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/callbacks.py +++ /dev/null @@ -1,29 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from cosmos_transfer1.diffusion.training.callbacks.grad_clip import GradClip -from cosmos_transfer1.diffusion.training.callbacks.iter_speed import IterSpeed -from cosmos_transfer1.diffusion.training.callbacks.low_precision import LowPrecisionCallback -from cosmos_transfer1.utils.callback import ProgressBarCallback -from cosmos_transfer1.utils.lazy_config import PLACEHOLDER -from cosmos_transfer1.utils.lazy_config import LazyCall as L - -BASIC_CALLBACKS = dict( - progress_bar=L(ProgressBarCallback)(), - grad_clip=L(GradClip)(fsdp_enabled=True, model_key="model"), - low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), - # for the first 1000 iterations, log the iteration speed per iteration, after that, log every 200 iterations - iter_speed=L(IterSpeed)(every_n=200, hit_thres=1000), -) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/checkpoint.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/checkpoint.py deleted file mode 100644 index 2b7c3648c17b1b54f4ccaa32c51c5d13dcf10690..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/checkpoint.py +++ /dev/null @@ -1,27 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict - -from cosmos_transfer1.checkpointer.fast_tp import Checkpointer as FastTPCheckpointer -from cosmos_transfer1.checkpointer.fsdp_checkpointer import FSDPCheckpointer -from cosmos_transfer1.checkpointer.multi_rank_checkpointer import MultiRankCheckpointer -from cosmos_transfer1.checkpointer.tp_checkpointer import Checkpointer as TPCheckpointer -from cosmos_transfer1.utils.lazy_config import LazyCall as L - -MULTI_RANK_CHECKPOINTER: Dict[str, str] = L(MultiRankCheckpointer)() -FSDP_CHECKPOINTER: Dict[str, str] = L(FSDPCheckpointer)() -MODEL_PARALLEL_CHECKPOINTER: Dict[str, str] = L(TPCheckpointer)() -FAST_TP_CHECKPOINTER: Dict[str, str] = L(FastTPCheckpointer)() diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/ema.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/ema.py deleted file mode 100644 index 8faacbddf3fbe4ba9f361b585b302a4e1ffe1e89..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/ema.py +++ /dev/null @@ -1,27 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from cosmos_transfer1.utils.ema import EMAModelTracker, PowerEMATracker -from cosmos_transfer1.utils.lazy_config import PLACEHOLDER -from cosmos_transfer1.utils.lazy_config import LazyCall as L -from cosmos_transfer1.utils.lazy_config import LazyDict - -PowerEMAConfig: LazyDict = L(PowerEMATracker.initialize_multi_rank_ema)( - model=PLACEHOLDER, enabled=True, rate=0.10, num=3 -) - -RegEMAConfig: LazyDict = L(EMAModelTracker.initialize_multi_rank_ema)( - model=PLACEHOLDER, enabled=True, rate=0.999, num=1 -) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py deleted file mode 100644 index 47db45d83e75ba68a39fa88c91898e42f7c5d654..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py +++ /dev/null @@ -1,211 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script will make + register the architecture + training-related configs for all the control modalities (one config per modality). -The configs are registered under the group "experiment" and can be used in training by passing the experiment name as an argument. 
- -Example usage: - - [dryrun, generate and inspect EdgeControl config]: - torchrun --nproc_per_node=1 -m cosmos_transfer1.diffusion.training.train --dryrun --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain - - [real run, 8 gpu, train SegControl from scratch]: - torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3_pretrain - - [real run, 8 gpu, train SegControl from released checkpoint]: - torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3_posttrain -""" - -import copy -import os - -from hydra.core.config_store import ConfigStore - -from cosmos_transfer1.checkpoints import COSMOS_TRANSFER1_7B_CHECKPOINT -from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS_COMB -from cosmos_transfer1.diffusion.inference.inference_utils import default_model_names -from cosmos_transfer1.diffusion.training.models.model_ctrl import ( # this one has training support - VideoDiffusionModelWithCtrl, -) -from cosmos_transfer1.diffusion.training.networks.general_dit_video_conditioned import VideoExtendGeneralDIT -from cosmos_transfer1.utils.lazy_config import LazyCall as L -from cosmos_transfer1.utils.lazy_config import LazyDict - -cs = ConfigStore.instance() - -num_frames = 121 -num_blocks = 28 -num_control_blocks = 3 - - -def make_ctrlnet_config_7b_training( - hint_key: str = "control_input_canny", num_control_blocks: int = 3, pretrain_model_path: str = "" -) -> LazyDict: - if pretrain_model_path == "": - job_name = f"CTRL_7Bv1pt3_lvg_tp_121frames_{hint_key}_block{num_control_blocks}_pretrain" - job_project = "cosmos_transfer1_pretrain" - else: - job_name = f"CTRL_7Bv1pt3_lvg_tp_121frames_{hint_key}_block{num_control_blocks}_posttrain" - job_project = "cosmos_transfer1_posttrain" - - config = LazyDict( - dict( - defaults=[ - {"override /net": "faditv2_7b"}, - {"override /net_ctrl": "faditv2_7b"}, - {"override /conditioner": "ctrlnet_add_fps_image_size_padding_mask"}, - {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, - # - {"override /hint_key": hint_key}, - {"override /callbacks": "basic"}, - {"override /checkpoint": "local"}, - {"override /ckpt_klass": "fast_tp"}, - # - # data: register your own data at cosmos_transfer1/diffusion/config/base/data.py - {"override /data_train": f"example_transfer_train_data_{hint_key}"}, - {"override /data_val": f"example_transfer_val_data_{hint_key}"}, - "_self_", - ], - # ckpt, config yaml files etc. will be saved under checkpoints//// - job=dict( - project=job_project, - group="CTRL_7Bv1_lvg", - name=job_name, - ), - optimizer=dict( - lr=2 ** (-14.3), # ~5e-5 - weight_decay=0.1, - betas=[0.9, 0.99], - eps=1e-10, - ), - checkpoint=dict( - load_path=pretrain_model_path, # Modify load_path as needed if you do post-training (fine-tuning). If training from scratch, leave it empty. - broadcast_via_filesystem=True, - save_iter=1000, # 1000 iterations per checkpoint. Update as needed. - load_training_state=False, - strict_resume=False, # TODO (qianlim): temporary hack: We have excluded the base model ckpt from each full controlnet. The base model weights are loaded below, see 'base_load_from'. 
- keys_not_to_resume=[], - ), - trainer=dict( - distributed_parallelism="ddp", - logging_iter=200, # will log iter speed, loss, etc. every 200 iterations. (Will log per-iteration speed for the first 1000 iterations.) - max_iter=999_999_999, - timestamp_seed=True, - ), - model_parallel=dict( - tensor_model_parallel_size=8, - sequence_parallel=True, - ), - model=dict( - fsdp_enabled=False, - context_parallel_size=1, - loss_reduce="mean", - latent_shape=[ - 16, - (num_frames - 1) // 8 + 1, # for 121 frames, this is 16 - 88, - 160, - ], - base_load_from=dict( - load_path=os.path.join( - "checkpoints", COSMOS_TRANSFER1_7B_CHECKPOINT, "checkpoints_tp", "base_model_model_mp_*.pt" - ) - ), # modify as needed. This is the TP version of base model ckpt (that's frozen during training). - finetune_base_model=False, - hint_mask=[True] * len(CTRL_HINT_KEYS_COMB[hint_key]), - hint_dropout_rate=0.3, - conditioner=dict( - video_cond_bool=dict( - condition_location="first_random_n", - cfg_unconditional_type="zero_condition_region_condition_mask", - apply_corruption_to_condition_region="noise_with_sigma", - condition_on_augment_sigma=False, - dropout_rate=0.0, - first_random_n_num_condition_t_max=2, - normalize_condition_latent=False, - augment_sigma_sample_p_mean=-3.0, - augment_sigma_sample_p_std=2.0, - augment_sigma_sample_multiplier=1.0, - ) - ), - net=L(VideoExtendGeneralDIT)( - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - extra_per_block_abs_pos_emb_type="learnable", - rope_h_extrapolation_ratio=1, - rope_t_extrapolation_ratio=2, - rope_w_extrapolation_ratio=1, - ), - adjust_video_noise=True, - net_ctrl=dict( - in_channels=17, - hint_channels=128, - num_blocks=num_blocks, - layer_mask=[True if (i >= num_control_blocks) else False for i in range(num_blocks)], - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - extra_per_block_abs_pos_emb_type="learnable", - ), - ema=dict( - enabled=True, - ), - ), - model_obj=L(VideoDiffusionModelWithCtrl)(), - scheduler=dict( - warm_up_steps=[2500], - cycle_lengths=[10000000000000], - f_start=[1.0e-6], - f_max=[1.0], - f_min=[1.0], - ), - ) - ) - return config - - -""" -Register configurations -The loop below will register ALL experiments CTRL_7Bv1pt3_lvg_tp_121frames_control_input_{hint_key_name}_block3_{pretrain_or_posttrain} for ALL hint_key_name. -Then in training command, simply need to pass the "experiment" arg to override the configs. See the docstring at top of this script for an example. - -# NOTE: To launch real post-training, convert the checkpoints to TP checkpoints first. See scripts/convert_ckpt_fsdp_to_tp.py. -""" -for key in CTRL_HINT_KEYS_COMB.keys(): - if key in ["control_input_hdmap", "control_input_lidar"]: - continue - # Register experiments for pretraining from scratch - config = make_ctrlnet_config_7b_training( - hint_key=key, num_control_blocks=num_control_blocks, pretrain_model_path="" - ) - cs.store( - group="experiment", - package="_global_", - name=config["job"]["name"], - node=config, - ) - # Register experiments for post-training from TP checkpoints. - hint_key_short = key.replace("control_input_", "") # "control_input_vis" -> "vis" - pretrain_ckpt_path = default_model_names[hint_key_short] - # note: The TP ckpt path are specified as .pt to the script, but actually the _model_mp_*.pt files will be loaded. 
- tp_ckpt_path = os.path.join( - "checkpoints", os.path.dirname(pretrain_ckpt_path), "checkpoints_tp", os.path.basename(pretrain_ckpt_path) - ) - config = make_ctrlnet_config_7b_training( - hint_key=key, num_control_blocks=num_control_blocks, pretrain_model_path=tp_ckpt_path - ) - cs.store( - group="experiment", - package="_global_", - name=config["job"]["name"], - node=config, - ) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_sample_av.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_sample_av.py deleted file mode 100644 index fb36f3617a152179d26e4959c5f5820d00952a2b..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_sample_av.py +++ /dev/null @@ -1,238 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: - - [real run, 8 gpu] torchrun --nproc_per_node=8 -m projects.edify_image.v4.train --dryrun --config=projects/edify_video/v4/config/ctrl/config.py -- experiment=CTRL_tp_121frames_control_input_bbox_image_block3 - - [debug small model, 1 gpu] torchrun --nproc_per_node=8 -m projects.edify_image.v4.train --config=projects/edify_video/v4/config/ctrl/config.py -- experiment=CTRL_tp_121frames_control_input_bbox_image_block3 model.net.num_blocks=1 model.context_parallel_size=1 checkpoint.load_path="" job.group=debug trainer.logging_iter=5 -""" - -import copy -import os - -from hydra.core.config_store import ConfigStore -from megatron.core import parallel_state -from torch.utils.data import DataLoader, DistributedSampler - -from cosmos_transfer1.checkpoints import COSMOS_TRANSFER1_7B_CHECKPOINT, COSMOS_TRANSFER1_7B_SAMPLE_AV_CHECKPOINT -from cosmos_transfer1.diffusion.config.base.data import get_sampler -from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS_COMB -from cosmos_transfer1.diffusion.datasets.example_transfer_dataset import AVTransferDataset, ExampleTransferDataset -from cosmos_transfer1.diffusion.inference.inference_utils import default_model_names -from cosmos_transfer1.diffusion.training.models.model_ctrl import ( # this one has training support - ShortVideoDiffusionModelWithCtrl, - VideoDiffusionModelWithCtrl, -) -from cosmos_transfer1.diffusion.training.networks.general_dit import GeneralDIT -from cosmos_transfer1.diffusion.training.networks.general_dit_video_conditioned import VideoExtendGeneralDIT -from cosmos_transfer1.utils.lazy_config import LazyCall as L -from cosmos_transfer1.utils.lazy_config import LazyDict - -cs = ConfigStore.instance() - -num_blocks = 28 -num_control_blocks = 3 -ckpt_root = "checkpoints/" -data_root = "datasets/waymo_transfer1/" - - -def make_ctrlnet_config( - hint_key: str = "control_input_hdmap", - num_control_blocks: int = 3, - pretrain_model_path: str = "", - t2w: bool = True, - num_frames=121, -) -> 
LazyDict: - if pretrain_model_path == "": - if t2w: - job_name = f"CTRL_7Bv1pt3_t2w_{num_frames}frames_{hint_key}_block{num_control_blocks}_pretrain" - job_project = "cosmos_transfer1_pretrain" - else: - job_name = f"CTRL_7Bv1pt3_lvg_{num_frames}frames_{hint_key}_block{num_control_blocks}_pretrain" - job_project = "cosmos_transfer1_pretrain" - else: - if t2w: - job_name = f"CTRL_7Bv1pt3_t2w_{num_frames}frames_{hint_key}_block{num_control_blocks}_posttrain" - job_project = "cosmos_transfer1_posttrain" - else: - job_name = f"CTRL_7Bv1pt3_lvg_{num_frames}frames_{hint_key}_block{num_control_blocks}_posttrain" - job_project = "cosmos_transfer1_posttrain" - example_multiview_dataset_waymo = L(AVTransferDataset)( - dataset_dir=data_root, - num_frames=num_frames, - hint_key=hint_key, - resolution="720", - view_keys=["front"], - sample_n_views=-1, - load_mv_emb=False, - ) - - ctrl_config = LazyDict( - dict( - defaults=[ - {"override /net": "faditv2_7b"}, - {"override /net_ctrl": "faditv2_7b"}, - {"override /conditioner": "ctrlnet_add_fps_image_size_padding_mask"}, - {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, - {"override /hint_key": hint_key}, - {"override /callbacks": "basic"}, - {"override /checkpoint": "local"}, - {"override /ckpt_klass": "fast_tp"}, - "_self_", - ], - job=dict(group="CTRL_7Bv1_sampleAV", project=job_project, name=job_name), - optimizer=dict( - lr=2 ** (-14.3), # ~5e-5 - weight_decay=0.1, - betas=[0.9, 0.99], - eps=1e-10, - ), - checkpoint=dict( - load_path=pretrain_model_path, - # Modify load_path as needed if you do post-training (fine-tuning). If training from scratch, leave it empty. - broadcast_via_filesystem=True, - save_iter=1000, - load_training_state=False, - strict_resume=False, - keys_not_to_resume=[], - ), - trainer=dict( - distributed_parallelism="ddp", - logging_iter=200, - max_iter=999_999_999, - callbacks=dict( - iter_speed=dict(hit_thres=5), - ), - timestamp_seed=True, # important for dataver dataloader!!! 
- ), - model_parallel=dict( - tensor_model_parallel_size=8, - sequence_parallel=True, - ), - model=dict( - fsdp_enabled=False, - context_parallel_size=1, - loss_reduce="mean", - latent_shape=[ - 16, - (num_frames - 1) // 8 + 1, - 88, - 160, - ], - base_load_from=dict( - load_path=os.path.join( - ckpt_root, - COSMOS_TRANSFER1_7B_SAMPLE_AV_CHECKPOINT, - "checkpoints_tp", - "base_model_model_mp_*.pt", - ) - ), - finetune_base_model=False, - hint_mask=[True], - hint_dropout_rate=0.15, - conditioner=dict( - video_cond_bool=dict( - condition_location="first_random_n", - cfg_unconditional_type="zero_condition_region_condition_mask", - apply_corruption_to_condition_region="noise_with_sigma", - condition_on_augment_sigma=False, - dropout_rate=0.0, - first_random_n_num_condition_t_max=0 if t2w else 2, - normalize_condition_latent=False, - augment_sigma_sample_p_mean=-3.0, - augment_sigma_sample_p_std=2.0, - augment_sigma_sample_multiplier=1.0, - ) - ), - net=dict( - in_channels=16, - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - extra_per_block_abs_pos_emb_type="learnable", - ), - adjust_video_noise=True, - net_ctrl=dict( - in_channels=16, - hint_channels=16, - num_blocks=num_blocks, - layer_mask=[True if (i >= num_control_blocks) else False for i in range(num_blocks)], - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - extra_per_block_abs_pos_emb_type="learnable", - ), - tokenizer=dict( - pixel_chunk_duration=num_frames, - ), - ), - model_obj=L(ShortVideoDiffusionModelWithCtrl)(), - dataloader_train=L(DataLoader)( - dataset=example_multiview_dataset_waymo, - sampler=L(get_sampler)(dataset=example_multiview_dataset_waymo), - batch_size=1, - drop_last=True, - pin_memory=True, - num_workers=8, - ), - dataloader_val=L(DataLoader)( - dataset=example_multiview_dataset_waymo, - sampler=L(get_sampler)(dataset=example_multiview_dataset_waymo), - batch_size=1, - drop_last=True, - pin_memory=True, - num_workers=8, - ), - ) - ) - return ctrl_config - - -all_hint_key = [ - "control_input_hdmap", - "control_input_lidar", -] - -for key in all_hint_key: - for num_frames in [57, 121]: - # Register experiments for pretraining from scratch - t2w_config = make_ctrlnet_config( - hint_key=key, num_control_blocks=num_control_blocks, pretrain_model_path="", t2w=True, num_frames=num_frames - ) - cs.store( - group="experiment", - package="_global_", - name=t2w_config["job"]["name"], - node=t2w_config, - ) - - # Register experiments for post-training from TP checkpoints. - hint_key_short = key.replace("control_input_", "") # "control_input_vis" -> "vis" - pretrain_ckpt_path = default_model_names[hint_key_short] - # note: The TP ckpt path are specified as .pt to the script, but actually the _model_mp_*.pt files will be loaded. 
- tp_ckpt_path = os.path.join( - ckpt_root, os.path.dirname(pretrain_ckpt_path), "checkpoints_tp", os.path.basename(pretrain_ckpt_path) - ) - config = make_ctrlnet_config( - hint_key=key, - num_control_blocks=num_control_blocks, - pretrain_model_path=tp_ckpt_path, - t2w=True, - num_frames=num_frames, - ) - cs.store( - group="experiment", - package="_global_", - name=config["job"]["name"], - node=config, - ) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_sample_av_mv.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_sample_av_mv.py deleted file mode 100644 index 29449cf44d67dcfaea81d0d89a08a29578795be0..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_sample_av_mv.py +++ /dev/null @@ -1,310 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: - - [real run, 8 gpu] torchrun --nproc_per_node=8 -m projects.edify_image.v4.train --dryrun --config=projects/edify_video/v4/config/ctrl/config.py -- experiment=CTRL_tp_121frames_control_input_bbox_image_block3 - - [debug small model, 1 gpu] torchrun --nproc_per_node=8 -m projects.edify_image.v4.train --config=projects/edify_video/v4/config/ctrl/config.py -- experiment=CTRL_tp_121frames_control_input_bbox_image_block3 model.net.num_blocks=1 model.context_parallel_size=1 checkpoint.load_path="" job.group=debug trainer.logging_iter=5 -""" - -import copy -import os - -from hydra.core.config_store import ConfigStore -from megatron.core import parallel_state -from torch.utils.data import DataLoader, DistributedSampler - -from cosmos_transfer1.checkpoints import ( - BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, - BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, - SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, -) -from cosmos_transfer1.diffusion.config.base.data import get_sampler -from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_HINT_KEYS_COMB -from cosmos_transfer1.diffusion.datasets.example_transfer_dataset import AVTransferDataset -from cosmos_transfer1.diffusion.training.models.extend_model_multiview_ctrl import ( - FSDPMultiVideoDiffusionModelWithCtrl, - MultiVideoDiffusionModelWithCtrl, -) -from cosmos_transfer1.diffusion.training.networks.general_dit import GeneralDIT -from cosmos_transfer1.diffusion.training.networks.general_dit_multi_camera import VideoExtendGeneralDIT -from cosmos_transfer1.utils.lazy_config import LazyCall as L -from cosmos_transfer1.utils.lazy_config import LazyDict - -cs = ConfigStore.instance() - -num_blocks = 28 -num_frames = 57 -num_control_blocks = 3 -ckpt_root = "checkpoints/" -data_root = "datasets/waymo_transfer/" - 
-t2w_mv_model_names = { - "hdmap": SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - "lidar": SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, -} -v2w_mv_model_names = { - "hdmap": SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - "lidar": SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, -} - - -def make_ctrlnet_config( - hint_key: str = "control_input_hdmap", - num_control_blocks: int = 3, - pretrain_model_path: str = "", - t2w: bool = True, - num_frames=121, -) -> LazyDict: - if pretrain_model_path == "": - if t2w: - job_name = f"CTRL_7Bv1pt3_t2w_sv2mv_{num_frames}frames_{hint_key}_block{num_control_blocks}_pretrain" - job_project = "cosmos_transfer1_pretrain" - else: - job_name = f"CTRL_7Bv1pt3_v2w_sv2mv_{num_frames}frames_{hint_key}_block{num_control_blocks}_pretrain" - job_project = "cosmos_transfer1_pretrain" - else: - if t2w: - job_name = f"CTRL_7Bv1pt3_t2w_sv2mv_{num_frames}frames_{hint_key}_block{num_control_blocks}_posttrain" - job_project = "cosmos_transfer1_posttrain" - else: - job_name = f"CTRL_7Bv1pt3_v2w_sv2mv_{num_frames}frames_{hint_key}_block{num_control_blocks}_posttrain" - job_project = "cosmos_transfer1_posttrain" - if t2w: - base_load_path = os.path.join( - ckpt_root, - os.path.dirname( - BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH if t2w else BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH - ), - "checkpoints_tp", - "t2w_base_model_model_mp_*.pt", - ) - else: - base_load_path = os.path.join( - ckpt_root, - os.path.dirname( - BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH if t2w else BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH - ), - "checkpoints_tp", - "v2w_base_model_model_mp_*.pt", - ) - example_multiview_dataset_waymo = L(AVTransferDataset)( - dataset_dir=data_root, - num_frames=num_frames, - hint_key=hint_key, - resolution="720", - view_keys=[ - "pinhole_front", - "pinhole_front_left", - "pinhole_front_right", - "pinhole_side_left", - "pinhole_side_right", - ], - caption_view_idx_map={0: 0, 1: 1, 2: 2, 3: 4, 4: 5}, - sample_n_views=3, - load_mv_emb=False, - is_train=True, - ) - - ctrl_config = LazyDict( - dict( - defaults=[ - {"override /net": "faditv2_7b"}, - {"override /net_ctrl": "faditv2_sv2mv"}, - {"override /conditioner": "view_cond_ctrlnet_add_fps_image_size_padding_mask"}, - {"override /tokenizer": "cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624"}, - {"override /hint_key": hint_key}, - {"override /callbacks": "basic"}, - {"override /checkpoint": "local"}, - {"override /ckpt_klass": "fast_tp"}, - "_self_", - ], - job=dict(group="CTRL_7Bv1_sampleAV", project=job_project, name=job_name), - optimizer=dict( - lr=2 ** (-14.3), # ~5e-5 - weight_decay=0.1, - betas=[0.9, 0.99], - eps=1e-10, - ), - checkpoint=dict( - load_path=pretrain_model_path, - # Modify load_path as needed if you do post-training (fine-tuning). If training from scratch, leave it empty. - broadcast_via_filesystem=True, - save_iter=1000, - load_training_state=False, - strict_resume=False, - keys_not_to_resume=[], - ), - trainer=dict( - distributed_parallelism="ddp", - logging_iter=200, - max_iter=999_999_999, - callbacks=dict( - iter_speed=dict(hit_thres=5), - ), - timestamp_seed=True, # important for dataver dataloader!!! 
- ), - model_parallel=dict( - tensor_model_parallel_size=8, - sequence_parallel=True, - ), - model=dict( - fsdp_enabled=False, - n_views=3, - context_parallel_size=1, - loss_reduce="mean", - latent_shape=[ - 16, - (num_frames - 1) // 8 + 1, - 88, - 160, - ], - base_load_from=dict(load_path=base_load_path), - finetune_base_model=False, - hint_mask=[True], - hint_dropout_rate=0.15, - conditioner=dict( - video_cond_bool=dict( - condition_location="first_cam" if t2w else "first_cam_and_random_n", - cfg_unconditional_type="zero_condition_region_condition_mask", - apply_corruption_to_condition_region="noise_with_sigma", - condition_on_augment_sigma=False, - dropout_rate=0.0, - first_random_n_num_condition_t_max=0 if t2w else 2, - normalize_condition_latent=False, - augment_sigma_sample_p_mean=-3.0, - augment_sigma_sample_p_std=2.0, - augment_sigma_sample_multiplier=1.0, - ) - ), - net=L(VideoExtendGeneralDIT)( - in_channels=17, - n_views=3, - n_views_emb=7, - view_condition_dim=6, - add_repeat_frame_embedding=True, - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - extra_per_block_abs_pos_emb_type="learnable", - num_blocks=num_blocks, - ), - adjust_video_noise=True, - net_ctrl=dict( - in_channels=16, # + 1 for cond_mask, +1 for padding mask, +6 for cam - hint_channels=16, - num_blocks=num_blocks, - n_views=3, - n_views_emb=7, - view_condition_dim=6, - add_repeat_frame_embedding=True, - is_extend_model=True, - layer_mask=[True if (i >= num_control_blocks) else False for i in range(num_blocks)], - extra_per_block_abs_pos_emb=True, - pos_emb_learnable=True, - extra_per_block_abs_pos_emb_type="learnable", - ), - tokenizer=dict( - pixel_chunk_duration=num_frames, - ), - ), - model_obj=L(MultiVideoDiffusionModelWithCtrl)(), - dataloader_train=L(DataLoader)( - dataset=example_multiview_dataset_waymo, - sampler=L(get_sampler)(dataset=example_multiview_dataset_waymo), - batch_size=1, - drop_last=True, - pin_memory=True, - num_workers=8, - ), - dataloader_val=L(DataLoader)( - dataset=example_multiview_dataset_waymo, - sampler=L(get_sampler)(dataset=example_multiview_dataset_waymo), - batch_size=1, - drop_last=True, - pin_memory=True, - num_workers=8, - ), - ) - ) - return ctrl_config - - -all_hint_key = [ - "control_input_hdmap", - "control_input_lidar", -] - -for key in all_hint_key: - # Register experiments for pretraining from scratch - t2w_config = make_ctrlnet_config( - hint_key=key, num_control_blocks=num_control_blocks, pretrain_model_path="", t2w=True, num_frames=num_frames - ) - v2w_config = make_ctrlnet_config( - hint_key=key, num_control_blocks=num_control_blocks, pretrain_model_path="", t2w=False, num_frames=num_frames - ) - - cs.store( - group="experiment", - package="_global_", - name=t2w_config["job"]["name"], - node=t2w_config, - ) - cs.store( - group="experiment", - package="_global_", - name=v2w_config["job"]["name"], - node=v2w_config, - ) - # Register experiments for post-training from TP checkpoints. - hint_key_short = key.replace("control_input_", "") # "control_input_vis" -> "vis" - t2w_pretrain_ckpt_path = t2w_mv_model_names[hint_key_short] - v2w_pretrain_ckpt_path = v2w_mv_model_names[hint_key_short] - # note: The TP ckpt path are specified as .pt to the script, but actually the _model_mp_*.pt files will be loaded. 
- t2w_tp_ckpt_path = os.path.join( - ckpt_root, os.path.dirname(t2w_pretrain_ckpt_path), "checkpoints_tp", os.path.basename(t2w_pretrain_ckpt_path) - ) - v2w_tp_ckpt_path = os.path.join( - ckpt_root, os.path.dirname(v2w_pretrain_ckpt_path), "checkpoints_tp", os.path.basename(v2w_pretrain_ckpt_path) - ) - # tp_ckpt_path = os.path.join(ckpt_root, SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH_dbg) - t2w_posttrain_config = make_ctrlnet_config( - hint_key=key, - num_control_blocks=num_control_blocks, - pretrain_model_path=t2w_tp_ckpt_path, - t2w=True, - num_frames=num_frames, - ) - v2w_posttrain_config = make_ctrlnet_config( - hint_key=key, - num_control_blocks=num_control_blocks, - pretrain_model_path=v2w_tp_ckpt_path, - t2w=False, - num_frames=num_frames, - ) - cs.store( - group="experiment", - package="_global_", - name=t2w_posttrain_config["job"]["name"], - node=t2w_posttrain_config, - ) - cs.store( - group="experiment", - package="_global_", - name=v2w_posttrain_config["job"]["name"], - node=v2w_posttrain_config, - ) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/optim.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/optim.py deleted file mode 100644 index 55558950b3b609e49b80f1a7a063143efec15e9b..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/optim.py +++ /dev/null @@ -1,40 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from cosmos_transfer1.diffusion.training.functional.lr_scheduler import LambdaLinearScheduler -from cosmos_transfer1.diffusion.training.utils.optim_instantiate import get_base_optimizer -from cosmos_transfer1.utils.lazy_config import PLACEHOLDER -from cosmos_transfer1.utils.lazy_config import LazyCall as L -from cosmos_transfer1.utils.lazy_config import LazyDict - -FusedAdamWConfig: LazyDict = L(get_base_optimizer)( - model=PLACEHOLDER, - lr=1e-4, - weight_decay=0.3, - betas=[0.9, 0.999], - optim_type="fusedadam", - eps=1e-8, - sharding=False, - master_weights=True, - capturable=True, -) - -LambdaLinearSchedulerConfig: LazyDict = L(LambdaLinearScheduler)( - warm_up_steps=[1000], - cycle_lengths=[10000000000000], - f_start=[1.0e-6], - f_max=[1.0], - f_min=[1.0], -) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/registry.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/registry.py deleted file mode 100644 index 930149a067cec24339e56adc6f678d0c3feb5f1f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/registry.py +++ /dev/null @@ -1,76 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Core training related registry. -""" - -from hydra.core.config_store import ConfigStore - -from cosmos_transfer1.checkpointer.ema_fsdp_checkpointer import CheckpointConfig -from cosmos_transfer1.diffusion.config.training.callbacks import BASIC_CALLBACKS -from cosmos_transfer1.diffusion.config.training.checkpoint import ( - FAST_TP_CHECKPOINTER, - FSDP_CHECKPOINTER, - MODEL_PARALLEL_CHECKPOINTER, - MULTI_RANK_CHECKPOINTER, -) -from cosmos_transfer1.diffusion.config.training.ema import PowerEMAConfig -from cosmos_transfer1.diffusion.config.training.optim import FusedAdamWConfig, LambdaLinearSchedulerConfig - - -def register_ema(cs): - cs.store(group="ema", package="model.ema", name="power", node=PowerEMAConfig) - - -def register_optimizer(cs): - cs.store(group="optimizer", package="optimizer", name="fusedadamw", node=FusedAdamWConfig) - - -def register_scheduler(cs): - cs.store(group="scheduler", package="scheduler", name="lambdalinear", node=LambdaLinearSchedulerConfig) - - -def register_callbacks(cs): - cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS) - - -def register_checkpoint_credential(cs): - CHECKPOINT_LOCAL = CheckpointConfig( - save_iter=1000, - load_path="", - load_training_state=False, - strict_resume=True, - ) - - cs.store(group="checkpoint", package="checkpoint", name="local", node=CHECKPOINT_LOCAL) - - -def register_checkpointer(cs): - cs.store(group="ckpt_klass", package="checkpoint.type", name="fsdp", node=FSDP_CHECKPOINTER) - cs.store(group="ckpt_klass", package="checkpoint.type", name="multi_rank", node=MULTI_RANK_CHECKPOINTER) - cs.store(group="ckpt_klass", package="checkpoint.type", name="tp", node=MODEL_PARALLEL_CHECKPOINTER) - cs.store(group="ckpt_klass", package="checkpoint.type", name="fast_tp", node=FAST_TP_CHECKPOINTER) - - -def register_configs(): - cs = ConfigStore.instance() - - register_optimizer(cs) - register_scheduler(cs) - register_ema(cs) - register_checkpoint_credential(cs) - register_checkpointer(cs) - register_callbacks(cs) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/registry_extra.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/registry_extra.py deleted file mode 100644 index 00f31d92f2266cd5d8cd4142f3426baa50a1b8ae..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/registry_extra.py +++ /dev/null @@ -1,146 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Registry for training experiments, callbacks and data. -""" - -import copy - -from hydra.core.config_store import ConfigStore - -import cosmos_transfer1.diffusion.config.training.registry as base_training_registry -from cosmos_transfer1.diffusion.config.base.data import register_data_ctrlnet -from cosmos_transfer1.diffusion.config.registry import register_conditioner -from cosmos_transfer1.diffusion.config.training.tokenizer import get_cosmos_diffusion_tokenizer_comp8x8x8 -from cosmos_transfer1.diffusion.config.transfer.conditioner import ( - CTRL_HINT_KEYS, - BaseVideoConditionerWithCtrlConfig, - VideoConditionerFpsSizePaddingWithCtrlConfig, - ViewConditionedVideoConditionerFpsSizePaddingWithCtrlConfig, -) -from cosmos_transfer1.diffusion.training.networks.general_dit import GeneralDIT -from cosmos_transfer1.diffusion.training.networks.general_dit_ctrl_enc import GeneralDITEncoder -from cosmos_transfer1.diffusion.training.networks.general_dit_ctrl_enc_multicamera import GeneralDITMulticamEncoder - -# from cosmos_transfer1.diffusion.config.registry import register_tokenizer -from cosmos_transfer1.utils.lazy_config import LazyCall as L -from cosmos_transfer1.utils.lazy_config import LazyDict - -FADITV2ConfigTrain: LazyDict = L(GeneralDIT)( - max_img_h=240, - max_img_w=240, - max_frames=128, - in_channels=16, - out_channels=16, - patch_spatial=2, - patch_temporal=1, - model_channels=4096, - block_config="FA-CA-MLP", - num_blocks=28, - num_heads=32, - concat_padding_mask=True, - pos_emb_cls="rope3d", - pos_emb_learnable=False, - pos_emb_interpolation="crop", - block_x_format="THWBD", - additional_timestamp_channels=None, - affline_emb_norm=True, - use_adaln_lora=True, - adaln_lora_dim=256, - legacy_patch_emb=False, -) - -num_blocks = FADITV2ConfigTrain["num_blocks"] -FADITV2EncoderConfigTrain = copy.deepcopy(FADITV2ConfigTrain) -FADITV2EncoderConfigTrain["_target_"] = GeneralDITEncoder -FADITV2EncoderConfigTrain["layer_mask"] = [True if i > num_blocks // 2 else False for i in range(num_blocks)] - -num_blocks = FADITV2ConfigTrain["num_blocks"] -FADITV2MultiCamEncoderConfig = copy.deepcopy(FADITV2ConfigTrain) -FADITV2MultiCamEncoderConfig["_target_"] = GeneralDITMulticamEncoder -FADITV2MultiCamEncoderConfig["layer_mask"] = [True if i > num_blocks // 2 else False for i in range(num_blocks)] - - -def register_net_train(cs): - cs.store( - group="net", - package="model.net", - name="faditv2_7b", - node=FADITV2ConfigTrain, - ) - cs.store(group="net_ctrl", package="model.net_ctrl", name="faditv2_7b", node=FADITV2EncoderConfigTrain) - cs.store(group="net_ctrl", package="model.net_ctrl", name="faditv2_sv2mv", node=FADITV2MultiCamEncoderConfig) - - -def register_conditioner_ctrlnet(cs): - cs.store( - group="conditioner", - package="model.conditioner", - name="ctrlnet", - node=BaseVideoConditionerWithCtrlConfig, - ) - cs.store( - group="conditioner", - package="model.conditioner", - name="ctrlnet_add_fps_image_size_padding_mask", - node=VideoConditionerFpsSizePaddingWithCtrlConfig, - ) - cs.store( - group="conditioner", - package="model.conditioner", - name="view_cond_ctrlnet_add_fps_image_size_padding_mask", - node=ViewConditionedVideoConditionerFpsSizePaddingWithCtrlConfig, - ) - - -def register_tokenizer(cs): - cs.store( - group="tokenizer", - package="model.tokenizer", - name="cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624", - 
node=get_cosmos_diffusion_tokenizer_comp8x8x8(resolution="720", chunk_duration=121), - ) - - -def register_configs(): - cs = ConfigStore.instance() - - # register all the basic configs: net, conditioner, tokenizer. - register_net_train(cs) - register_conditioner(cs) - register_conditioner_ctrlnet(cs) - register_tokenizer(cs) - - # register training configs: optimizer, scheduler, callbacks, etc. - base_training_registry.register_configs() - - # register data, experiment, callbacks - register_data_ctrlnet(cs) - - # register hint keys - for hint_key in CTRL_HINT_KEYS: - cs.store( - group="hint_key", - package="model", - name=hint_key, - node=dict(hint_key=dict(hint_key=hint_key, grayscale=False)), - ) - cs.store( - group="hint_key", - package="model", - name=f"{hint_key}_grayscale", - node=dict(hint_key=dict(hint_key=hint_key, grayscale=True)), - ) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/tokenizer.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/tokenizer.py deleted file mode 100644 index 0c4b2da9c5b446271c3b14aa3ceac3f14db74513..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/training/tokenizer.py +++ /dev/null @@ -1,81 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import omegaconf - -from cosmos_transfer1.diffusion.training.modules.pretrained_vae import ( - JITVAE, - JointImageVideoSharedJITTokenizer, - VideoJITTokenizer, -) -from cosmos_transfer1.utils.lazy_config import LazyCall as L - -TOKENIZER_OPTIONS = {} - - -def tokenizer_register(key): - def decorator(func): - TOKENIZER_OPTIONS[key] = func - return func - - return decorator - - -@tokenizer_register("cosmos_diffusion_tokenizer_comp8x8x8") -def get_cosmos_diffusion_tokenizer_comp8x8x8(resolution: str, chunk_duration: int) -> omegaconf.dictconfig.DictConfig: - assert resolution in ["720"] - - pixel_chunk_duration = chunk_duration - temporal_compression_factor = 8 - spatial_compression_factor = 8 - - return L(VideoJITTokenizer)( - name="cosmos_1_0_diffusion_tokenizer", - enc_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit", - dec_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit", - mean_std_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt", - latent_ch=16, - is_bf16=True, - pixel_chunk_duration=pixel_chunk_duration, - temporal_compression_factor=temporal_compression_factor, - spatial_compression_factor=spatial_compression_factor, - spatial_resolution=resolution, - ) - - # return L(JointImageVideoSharedJITTokenizer)( - # video_vae=L(VideoJITTokenizer)( - # name="cosmos_1_0_diffusion_tokenizer", - # enc_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit", - # dec_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit", - # mean_std_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt", - # latent_ch=16, - # is_bf16=True, - # pixel_chunk_duration=pixel_chunk_duration, - # temporal_compression_factor=temporal_compression_factor, - # spatial_compression_factor=spatial_compression_factor, - # spatial_resolution=resolution, - # ), - # image_vae=L(JITVAE)( - # name="cosmos_1_0_diffusion_tokenizer", - # enc_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit", - # dec_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit", - # mean_std_fp="checkpoints/nvidia/Cosmos-Tokenize1-CV8x8x8-720p/mean_std.pt", - # latent_ch=16, - # is_image=False, - # is_bf16=True, - # ), - # name="cosmos_1_0_diffusion_tokenizer", - # latent_ch=16, - # ) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/__init__.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/augmentors.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/augmentors.py deleted file mode 100644 index 6d915605e0742856997b86028b3182c7eb41d907..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/augmentors.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from cosmos_transfer1.diffusion.config.transfer.blurs import ( - BilateralFilterConfig, - BlurAugmentorConfig, - BlurCombinationConfig, -) -from cosmos_transfer1.utils.lazy_config import LazyCall as L - -# predefined BilateralFilterConfig with different strength level -NoFilterConfig = L(BilateralFilterConfig)(use_random=False, d=1, sigma_color=1, sigma_space=1, iter=1) - -LowBilateralFilterConfig = L(BilateralFilterConfig)(use_random=False, d=15, sigma_color=100, sigma_space=50, iter=1) - -MediumBilateralFilterConfig = L(BilateralFilterConfig)(use_random=False, d=30, sigma_color=150, sigma_space=100, iter=1) - -HighBilateralFilterConfig = L(BilateralFilterConfig)(use_random=False, d=50, sigma_color=300, sigma_space=150, iter=1) - -BilateralOnlyBlurAugmentorConfig = {} -for strength, blur_config in zip( - ["none", "very_low", "low", "medium", "high", "very_high"], - [ - NoFilterConfig, - LowBilateralFilterConfig, - LowBilateralFilterConfig, - MediumBilateralFilterConfig, - HighBilateralFilterConfig, - HighBilateralFilterConfig, - ], -): - BlurConfig = L(BlurCombinationConfig)( - blur_types=["bilateral"], - probability=1.0, - bilateral_filter=blur_config, - ) - downscale_factor = { - "none": 1, - "very_low": 1, - "low": 4, - "medium": 2, - "high": 1, - "very_high": 4, - } - BilateralOnlyBlurAugmentorConfig[strength] = L(BlurAugmentorConfig)( - blur_combinations=[BlurConfig], - downscale_factor=[downscale_factor[strength]], - ) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/blurs.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/blurs.py deleted file mode 100644 index b8f5b7f3f30eae679b84d387df88f3cc161c3fdf..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/blurs.py +++ /dev/null @@ -1,180 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
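# [Editor's note, not part of the deleted blurs.py below] For reference, the
# "medium" entry of BilateralOnlyBlurAugmentorConfig built above resolves, once
# the LazyCall wrappers are instantiated, to roughly the following; plain dicts
# stand in here for the attrs config classes defined in blurs.py.
medium_bilateral_only = {
    "downscale_factor": [2],
    "blur_combinations": [
        {
            "blur_types": ["bilateral"],
            "probability": 1.0,
            "bilateral_filter": {
                "use_random": False,
                "d": 30,
                "sigma_color": 150,
                "sigma_space": 100,
                "iter": 1,
            },
        }
    ],
}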
- -from typing import List, Optional - -import attrs - -from cosmos_transfer1.utils.lazy_config import LazyCall as L - - -@attrs.define -class GaussianBlurConfig: - """Configuration for Gaussian blur""" - - use_random: bool = False - # if use_random is False, then optionally define the param values - ksize: int = 25 - sigmaX: float = 12.5 - - # if use_random is True, then optionally define the range - ksize_min: int = 21 - ksize_max: int = 29 - sigmaX_min: float = 10.5 - sigmaX_max: float = 14.5 - - -LowGaussianBlurConfig = L(GaussianBlurConfig)(ksize=21, sigmaX=10.5) - - -@attrs.define -class GuidedFilterConfig: - """Configuration for Guided filter""" - - use_random: bool = False - # if use_random is False, then optionally define the param values - radius: int = 45 - eps: float = 0.15 - scale: int = 10 - - # if use_random is True, then optionally define the range - radius_min: int = 41 - radius_max: int = 49 - eps_min: float = 0.1 - eps_max: float = 0.2 - scale_min: int = 3 - scale_max: int = 18 - - -@attrs.define -class BilateralFilterConfig: - """Configuration for Bilateral filter""" - - use_random: bool = False - # if use_random is False, then optionally define the param values - d: int = 30 - sigma_color: int = 150 - sigma_space: int = 100 - iter: int = 1 - - # if use_random is True, then optionally define the range - d_min: int = 15 - d_max: int = 50 - sigma_color_min: int = 100 - sigma_color_max: int = 300 - sigma_space_min: int = 50 - sigma_space_max: int = 150 - iter_min: int = 1 - iter_max: int = 4 - - -@attrs.define -class MedianBlurConfig: - """Configuration for Median blur""" - - use_random: bool = False - # if use_random is False, then optionally define the param values - ksize: int = 11 - - # if use_random is True, then optionally define the range - ksize_min: int = 9 - ksize_max: int = 15 - - -@attrs.define -class LaplacianOfGaussianConfig: - """Configuration for LoG filter""" - - use_random: bool = False - # if use_random is False, then optionally define the param values - ksize: int = 5 - sigma: float = 1.4 - binarize: bool = False - threshold: float = 0.0 - - # if use_random is True, then optionally define the range - ksize_min: int = 3 - ksize_max: int = 7 - sigma_min: float = 0.5 - sigma_max: float = 3.0 - threshold_min: float = 10.0 - threshold_max: float = 30.0 - - -@attrs.define -class AnisotropicDiffusionConfig: - """Configuration for Anisotropic Diffusion""" - - use_random: bool = False - alpha: float = 0.25 - K: float = 0.15 - niters: int = 12 - - # if use_random is True, then optionally define the range - alpha_min: float = 0.2 - alpha_max: float = 0.3 - K_min: float = 0.1 - K_max: float = 0.2 - niters_min: int = 10 - niters_max: int = 14 - - -@attrs.define -class BlurCombinationConfig: - """Configuration for a combination of blurs with associated probability""" - - # list of choices are: ["gaussian", "guided", "bilateral", "median", "log", "anisotropic"] - # the corresponding config must be defined for each item in this blur_types list - blur_types: List[str] - probability: float - gaussian_blur: Optional[GaussianBlurConfig] = None - guided_filter: Optional[GuidedFilterConfig] = None - bilateral_filter: Optional[BilateralFilterConfig] = None - median_blur: Optional[MedianBlurConfig] = None - log: Optional[LaplacianOfGaussianConfig] = None - anisotropic_diffusion: Optional[AnisotropicDiffusionConfig] = None - - -@attrs.define -class BlurAugmentorConfig: - """Configuration for blur augmentation with multiple combinations""" - - # probabilities from the list of 
combinations should add up to 1.0 - blur_combinations: List[BlurCombinationConfig] = [] - downscale_factor: List[int] = [1] - - -# random blur for training the VisControl -random_blur_config = BlurAugmentorConfig( - downscale_factor=list(range(1, 5)), - blur_combinations=[ - BlurCombinationConfig( - blur_types=["bilateral"], - probability=0.5, - bilateral_filter=BilateralFilterConfig(use_random=True), - ), - BlurCombinationConfig( - blur_types=["gaussian"], - probability=0.3, - gaussian_blur=GaussianBlurConfig(use_random=True), - ), - BlurCombinationConfig( - blur_types=["bilateral", "gaussian"], - probability=0.2, - bilateral_filter=BilateralFilterConfig(use_random=True), - gaussian_blur=GaussianBlurConfig(use_random=True), - ), - ], -) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/conditioner.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/conditioner.py deleted file mode 100644 index 51002db62b86cc4a24d2958daa30662f6a56218d..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/conditioner.py +++ /dev/null @@ -1,110 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import attrs - -from cosmos_transfer1.diffusion.conditioner import ( - FrameRepeatAttr, - VideoConditionerWithCtrl, - ViewConditionedVideoConditionerWithCtrl, -) -from cosmos_transfer1.diffusion.config.base.conditioner import ( - FPSConfig, - ImageSizeConfig, - NumFramesConfig, - PaddingMaskConfig, - TextConfig, - VideoCondBoolConfig, -) -from cosmos_transfer1.diffusion.datasets.augmentors.control_input import ( - AddControlInput, - AddControlInputDepth, - AddControlInputEdge, - AddControlInputHDMAP, - AddControlInputKeypoint, - AddControlInputLIDAR, - AddControlInputSeg, - AddControlInputUpscale, -) -from cosmos_transfer1.utils.lazy_config import LazyCall as L -from cosmos_transfer1.utils.lazy_config import LazyDict - -CTRL_HINT_KEYS = [ - "control_input_edge", - "control_input_vis", - "control_input_depth", - "control_input_seg", - "control_input_keypoint", - "control_input_upscale", - "control_input_hdmap", - "control_input_lidar", -] - -CTRL_HINT_KEYS_COMB = { - "control_input_vis": [AddControlInput], - "control_input_edge": [AddControlInputEdge], - "control_input_depth": [AddControlInputDepth], - "control_input_seg": [AddControlInputSeg], - "control_input_keypoint": [AddControlInputKeypoint], - "control_input_upscale": [AddControlInputUpscale], - "control_input_hdmap": [AddControlInputHDMAP], - "control_input_lidar": [AddControlInputLIDAR], -} - -# SS=self-supervised -SS_CTRL_HINT_KEYS = [ - "control_input_canny", - "control_input_canny_blur", - "control_input_blur", - "control_input_upscale", -] - -# for data loading. 
Defining corresponding sub-folders in the data folder -CTRL_AUG_KEYS = { - "depth": "depth", - "seg": "segmentation", - "keypoint": "keypoint", -} - - -@attrs.define(slots=False) -class FrameRepeatConfig: - obj: LazyDict = L(FrameRepeatAttr)() - dropout_rate: float = 0.0 - input_key: str = "frame_repeat" - - -BaseVideoConditionerWithCtrlConfig: LazyDict = L(VideoConditionerWithCtrl)( - text=TextConfig(), -) - -VideoConditionerFpsSizePaddingWithCtrlConfig: LazyDict = L(VideoConditionerWithCtrl)( - text=TextConfig(), - fps=FPSConfig(), - num_frames=NumFramesConfig(), - image_size=ImageSizeConfig(), - padding_mask=PaddingMaskConfig(), - video_cond_bool=VideoCondBoolConfig(), -) - -ViewConditionedVideoConditionerFpsSizePaddingWithCtrlConfig: LazyDict = L(ViewConditionedVideoConditionerWithCtrl)( - text=TextConfig(), - fps=FPSConfig(), - num_frames=NumFramesConfig(), - image_size=ImageSizeConfig(), - padding_mask=PaddingMaskConfig(), - video_cond_bool=VideoCondBoolConfig(), - frame_repeat=FrameRepeatConfig(), -) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/config.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/config.py deleted file mode 100644 index b0db759934769d46150958d460450988757b068a..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/config.py +++ /dev/null @@ -1,57 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
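# [Editor's note, not part of the deleted config.py below] A minimal,
# self-contained illustration of how the CTRL_AUG_KEYS table above is consumed
# downstream (see augmentor_provider.py later in this diff): any sub-folder
# whose key is a substring of the hint key is added to the loader's input keys.
CTRL_AUG_KEYS = {"depth": "depth", "seg": "segmentation", "keypoint": "keypoint"}
hint_key = "control_input_seg"
extra_inputs = [folder for key, folder in CTRL_AUG_KEYS.items() if key in hint_key]
print(extra_inputs)  # ['segmentation']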
- -from typing import Any, List - -import attrs - -from cosmos_transfer1.diffusion.config.transfer.model import CtrlModelConfig -from cosmos_transfer1.diffusion.config.transfer.registry import register_configs -from cosmos_transfer1.diffusion.model.model_ctrl import VideoDiffusionModelWithCtrl -from cosmos_transfer1.utils import config -from cosmos_transfer1.utils.config_helper import import_all_modules_from_package -from cosmos_transfer1.utils.lazy_config import PLACEHOLDER -from cosmos_transfer1.utils.lazy_config import LazyCall as L -from cosmos_transfer1.utils.lazy_config import LazyDict - - -@attrs.define(slots=False) -class Config(config.Config): - # default config groups that will be used unless overwritten - # see config groups in registry.py - defaults: List[Any] = attrs.field( - factory=lambda: [ - "_self_", - {"net": None}, - {"net_ctrl": None}, - {"hint_key": "control_input_edge"}, - {"conditioner": "ctrlnet_add_fps_image_size_padding_mask"}, - {"tokenizer": "vae1"}, - {"experiment": None}, - ] - ) - model_obj: LazyDict = L(VideoDiffusionModelWithCtrl)( - config=PLACEHOLDER, - ) - - -def make_config(): - c = Config( - model=CtrlModelConfig(), - ) - register_configs() - - import_all_modules_from_package("cosmos_transfer1.diffusion.config.inference") - return c diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/model.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/model.py deleted file mode 100644 index 9a8860599911251173d9a9ff39a15734818f27fe..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/model.py +++ /dev/null @@ -1,33 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import attrs - -from cosmos_transfer1.diffusion.config.base.model import DefaultModelConfig -from cosmos_transfer1.utils.lazy_config import LazyDict - - -@attrs.define(slots=False) -class CtrlModelConfig(DefaultModelConfig): - net_ctrl: LazyDict = None - hint_key: str = None - base_load_from: LazyDict = None - finetune_base_model: bool = False - hint_mask: list = [True] - hint_dropout_rate: float = 0.0 - num_control_blocks: int = 3 - random_drop_control_blocks: bool = False - pixel_corruptor: LazyDict = None - n_views: int = 1 diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/net_ctrl.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/net_ctrl.py deleted file mode 100644 index 712b33e60e775eefe9c4ef38a645325cab399b4c..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/net_ctrl.py +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy - -from cosmos_transfer1.diffusion.config.base.net import FADITV2Config -from cosmos_transfer1.diffusion.networks.general_dit_ctrl_enc import GeneralDITEncoder -from cosmos_transfer1.diffusion.networks.general_dit_ctrl_enc_multiview import GeneralDITMultiviewEncoder - -num_blocks = FADITV2Config["num_blocks"] -FADITV2EncoderConfig = copy.deepcopy(FADITV2Config) -FADITV2EncoderConfig["_target_"] = GeneralDITEncoder -FADITV2EncoderConfig["layer_mask"] = [True if i > num_blocks // 2 else False for i in range(num_blocks)] - -num_blocks = FADITV2Config["num_blocks"] -FADITV2MultiCamEncoderConfig = copy.deepcopy(FADITV2Config) -FADITV2MultiCamEncoderConfig["_target_"] = GeneralDITMultiviewEncoder -FADITV2MultiCamEncoderConfig["layer_mask"] = [True if i > num_blocks // 2 else False for i in range(num_blocks)] diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/registry.py b/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/registry.py deleted file mode 100644 index 2c196340b093608ab9092f28c1e375f183d10b21..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/config/transfer/registry.py +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from hydra.core.config_store import ConfigStore - -import cosmos_transfer1.diffusion.config.registry as base_registry -from cosmos_transfer1.diffusion.config.transfer.conditioner import ( - CTRL_HINT_KEYS, - BaseVideoConditionerWithCtrlConfig, - VideoConditionerFpsSizePaddingWithCtrlConfig, - ViewConditionedVideoConditionerFpsSizePaddingWithCtrlConfig, -) -from cosmos_transfer1.diffusion.config.transfer.net_ctrl import FADITV2EncoderConfig, FADITV2MultiCamEncoderConfig - - -def register_experiment_ctrlnet(cs): - """ - transfer model related registry: controlnet architecture, hint keys, etc. - """ - # TODO: maybe we should change the registered 'name' (faditv2_7b) here; it's the dit-encoder for net_ctrl - # but current naming is the same as the full DiT in the main 'net' group that's defined - # in cosmos_transfer1/diffusion/config/registry.py. Isn't an error but could be confusing. 
- cs.store(group="net_ctrl", package="model.net_ctrl", name="faditv2_7b", node=FADITV2EncoderConfig) - cs.store(group="net_ctrl", package="model.net_ctrl", name="faditv2_7b_mv", node=FADITV2MultiCamEncoderConfig) - - cs.store(group="conditioner", package="model.conditioner", name="ctrlnet", node=BaseVideoConditionerWithCtrlConfig) - cs.store( - group="conditioner", - package="model.conditioner", - name="ctrlnet_add_fps_image_size_padding_mask", - node=VideoConditionerFpsSizePaddingWithCtrlConfig, - ) - cs.store( - group="conditioner", - package="model.conditioner", - name="view_cond_ctrlnet_add_fps_image_size_padding_mask", - node=ViewConditionedVideoConditionerFpsSizePaddingWithCtrlConfig, - ) - for hint_key in CTRL_HINT_KEYS: - cs.store( - group="hint_key", - package="model", - name=hint_key, - node=dict(hint_key=dict(hint_key=hint_key, grayscale=False)), - ) - cs.store( - group="hint_key", - package="model", - name=f"{hint_key}_grayscale", - node=dict(hint_key=dict(hint_key=hint_key, grayscale=True)), - ) - - -def register_configs(): - cs = ConfigStore.instance() - base_registry.register_configs() - register_experiment_ctrlnet(cs) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentor_provider.py b/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentor_provider.py deleted file mode 100644 index 08c146276523928baa3ceda795c42ee266432f8a..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentor_provider.py +++ /dev/null @@ -1,142 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
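# [Editor's note, not part of the deleted augmentor_provider.py below] A
# minimal, self-contained sketch of the cs.store(group=..., package="model", ...)
# pattern used by register_experiment_ctrlnet above; it only needs hydra-core.
# Selecting `hint_key=control_input_seg` at compose time should then place the
# stored node under cfg.model, i.e. cfg.model.hint_key.hint_key == "control_input_seg".
from hydra.core.config_store import ConfigStore

cs = ConfigStore.instance()
cs.store(
    group="hint_key",
    package="model",
    name="control_input_seg",
    node=dict(hint_key=dict(hint_key="control_input_seg", grayscale=False)),
)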
- -from cosmos_transfer1.diffusion.config.transfer.blurs import BlurAugmentorConfig, random_blur_config -from cosmos_transfer1.diffusion.config.transfer.conditioner import CTRL_AUG_KEYS, CTRL_HINT_KEYS, CTRL_HINT_KEYS_COMB -from cosmos_transfer1.diffusion.datasets.augmentors.basic_augmentors import ( - ReflectionPadding, - ResizeLargestSideAspectPreserving, -) -from cosmos_transfer1.diffusion.datasets.augmentors.control_input import ( - VIDEO_RES_SIZE_INFO, - AddControlInput, - AddControlInputComb, -) -from cosmos_transfer1.diffusion.datasets.augmentors.merge_datadict import DataDictMerger -from cosmos_transfer1.utils.lazy_config import LazyCall as L - -AUGMENTOR_OPTIONS = {} - - -def augmentor_register(key): - def decorator(func): - AUGMENTOR_OPTIONS[key] = func - return func - - return decorator - - -@augmentor_register("video_basic_augmentor") -def get_video_augmentor( - resolution: str, - blur_config=None, -): - return { - "merge_datadict": L(DataDictMerger)( - input_keys=["video"], - output_keys=[ - "video", - "fps", - "num_frames", - "frame_start", - "frame_end", - "orig_num_frames", - ], - ), - "resize_largest_side_aspect_ratio_preserving": L(ResizeLargestSideAspectPreserving)( - input_keys=["video"], - args={"size": VIDEO_RES_SIZE_INFO[resolution]}, - ), - "reflection_padding": L(ReflectionPadding)( - input_keys=["video"], - args={"size": VIDEO_RES_SIZE_INFO[resolution]}, - ), - } - - -""" -register all the video ctrlnet augmentors for data loading -""" -for hint_key in CTRL_HINT_KEYS: - - def get_video_ctrlnet_augmentor(hint_key, use_random=True): - def _get_video_ctrlnet_augmentor( - resolution: str, - blur_config: BlurAugmentorConfig = random_blur_config, - ): - if hint_key == "control_input_keypoint": - add_control_input = L(AddControlInputComb)( - input_keys=["", "video"], - output_keys=[hint_key], - args={ - "comb": CTRL_HINT_KEYS_COMB[hint_key], - "use_openpose_format": True, - "kpt_thr": 0.6, - "human_kpt_line_width": 4, - }, - use_random=use_random, - blur_config=blur_config, - ) - elif hint_key in CTRL_HINT_KEYS_COMB: - add_control_input = L(AddControlInputComb)( - input_keys=["", "video"], - output_keys=[hint_key], - args={"comb": CTRL_HINT_KEYS_COMB[hint_key]}, - use_random=use_random, - blur_config=blur_config, - ) - else: - add_control_input = L(AddControlInput)( - input_keys=["", "video"], - output_keys=[hint_key], - use_random=use_random, - blur_config=blur_config, - ) - input_keys = ["video"] - output_keys = [ - "video", - "fps", - "num_frames", - "frame_start", - "frame_end", - "orig_num_frames", - ] - for key, value in CTRL_AUG_KEYS.items(): - if key in hint_key: - input_keys.append(value) - output_keys.append(value) - - augmentation = { - # "merge_datadict": L(DataDictMerger)( - # input_keys=input_keys, - # output_keys=output_keys, - # ), - # this addes the control input tensor to the data dict - "add_control_input": add_control_input, - # this resizes both the video and the control input to the model's required input size - "resize_largest_side_aspect_ratio_preserving": L(ResizeLargestSideAspectPreserving)( - input_keys=["video", hint_key], - args={"size": VIDEO_RES_SIZE_INFO[resolution]}, - ), - "reflection_padding": L(ReflectionPadding)( - input_keys=["video", hint_key], - args={"size": VIDEO_RES_SIZE_INFO[resolution]}, - ), - } - return augmentation - - return _get_video_ctrlnet_augmentor - - augmentor_register(f"video_ctrlnet_augmentor_{hint_key}")(get_video_ctrlnet_augmentor(hint_key)) diff --git 
a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors.py b/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors.py deleted file mode 100644 index bd855ea07b10c9b937bdc94089924342e3c5185c..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors.py +++ /dev/null @@ -1,240 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import omegaconf -import torch -import torchvision.transforms.functional as transforms_F - -from cosmos_transfer1.diffusion.datasets.augmentors.control_input import Augmentor -from cosmos_transfer1.diffusion.datasets.dataset_utils import obtain_augmentation_size, obtain_image_size - - -class ReflectionPadding(Augmentor): - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Performs reflection padding. This function also returns a padding mask. - - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict where images are center cropped. - """ - - assert self.args is not None, "Please specify args in augmentation" - if self.output_keys is None: - self.output_keys = self.input_keys - - # Obtain image and augmentation sizes - orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) - target_size = obtain_augmentation_size(data_dict, self.args) - - assert isinstance(target_size, (tuple, omegaconf.listconfig.ListConfig)), "Please specify target size as tuple" - target_w, target_h = target_size - - target_w = int(target_w) - target_h = int(target_h) - - # Calculate padding vals - padding_left = int((target_w - orig_w) / 2) - padding_right = target_w - orig_w - padding_left - padding_top = int((target_h - orig_h) / 2) - padding_bottom = target_h - orig_h - padding_top - padding_vals = [padding_left, padding_top, padding_right, padding_bottom] - - for inp_key, out_key in zip(self.input_keys, self.output_keys): - if max(padding_vals[0], padding_vals[2]) >= orig_w or max(padding_vals[1], padding_vals[3]) >= orig_h: - # In this case, we can't perform reflection padding. This is because padding values - # are larger than the image size. So, perform edge padding instead. - data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="edge") - else: - # Perform reflection padding - data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="reflect") - - if out_key != inp_key: - del data_dict[inp_key] - - # Return padding_mask when padding is performed. - # Padding mask denotes which pixels are padded. 
- padding_mask = torch.ones((1, target_h, target_w)) - padding_mask[:, padding_top : (padding_top + orig_h), padding_left : (padding_left + orig_w)] = 0 - data_dict["padding_mask"] = padding_mask - data_dict["image_size"] = torch.tensor([target_h, target_w, orig_h, orig_w], dtype=torch.float) - - return data_dict - - -class ResizeSmallestSide(Augmentor): - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Performs resizing to smaller side - - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict where images are resized - """ - - if self.output_keys is None: - self.output_keys = self.input_keys - assert self.args is not None, "Please specify args in augmentations" - - for inp_key, out_key in zip(self.input_keys, self.output_keys): - out_size = obtain_augmentation_size(data_dict, self.args) - assert isinstance(out_size, int), "Arg size in resize should be an integer" - data_dict[out_key] = transforms_F.resize( - data_dict[inp_key], - size=out_size, # type: ignore - interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), - antialias=True, - ) - if out_key != inp_key: - del data_dict[inp_key] - return data_dict - - -class ResizeLargestSide(Augmentor): - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Performs resizing to larger side - - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict where images are resized - """ - - if self.output_keys is None: - self.output_keys = self.input_keys - assert self.args is not None, "Please specify args in augmentations" - - for inp_key, out_key in zip(self.input_keys, self.output_keys): - out_size = obtain_augmentation_size(data_dict, self.args) - assert isinstance(out_size, int), "Arg size in resize should be an integer" - orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) - - scaling_ratio = min(out_size / orig_w, out_size / orig_h) - target_size = [int(scaling_ratio * orig_h), int(scaling_ratio * orig_w)] - - data_dict[out_key] = transforms_F.resize( - data_dict[inp_key], - size=target_size, - interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), - antialias=True, - ) - if out_key != inp_key: - del data_dict[inp_key] - return data_dict - - -class ResizeSmallestSideAspectPreserving(Augmentor): - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Performs aspect-ratio preserving resizing. - Image is resized to the dimension which has the smaller ratio of (size / target_size). - First we compute (w_img / w_target) and (h_img / h_target) and resize the image - to the dimension that has the smaller of these ratios. 
- - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict where images are resized - """ - - if self.output_keys is None: - self.output_keys = self.input_keys - assert self.args is not None, "Please specify args in augmentations" - - img_size = obtain_augmentation_size(data_dict, self.args) - assert isinstance( - img_size, (tuple, omegaconf.listconfig.ListConfig) - ), f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" - img_w, img_h = img_size - - orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) - scaling_ratio = max((img_w / orig_w), (img_h / orig_h)) - target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) - - assert ( - target_size[0] >= img_h and target_size[1] >= img_w - ), f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" - - for inp_key, out_key in zip(self.input_keys, self.output_keys): - data_dict[out_key] = transforms_F.resize( - data_dict[inp_key], - size=target_size, # type: ignore - interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), - antialias=True, - ) - - if out_key != inp_key: - del data_dict[inp_key] - return data_dict - - -class ResizeLargestSideAspectPreserving(Augmentor): - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Performs aspect-ratio preserving resizing. - Image is resized to the dimension which has the larger ratio of (size / target_size). - First we compute (w_img / w_target) and (h_img / h_target) and resize the image - to the dimension that has the larger of these ratios. - - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict where images are resized - """ - - if self.output_keys is None: - self.output_keys = self.input_keys - assert self.args is not None, "Please specify args in augmentations" - - img_size = obtain_augmentation_size(data_dict, self.args) - assert isinstance( - img_size, (tuple, omegaconf.listconfig.ListConfig) - ), f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" - img_w, img_h = img_size - - orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) - scaling_ratio = min((img_w / orig_w), (img_h / orig_h)) - target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) - - assert ( - target_size[0] <= img_h and target_size[1] <= img_w - ), f"Resize error. 
orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" - - for inp_key, out_key in zip(self.input_keys, self.output_keys): - data_dict[out_key] = transforms_F.resize( - data_dict[inp_key], - size=target_size, # type: ignore - interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), - antialias=True, - ) - - if out_key != inp_key: - del data_dict[inp_key] - return data_dict diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/__init__.py b/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/basic_augmentors.py b/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/basic_augmentors.py deleted file mode 100644 index bd855ea07b10c9b937bdc94089924342e3c5185c..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/basic_augmentors.py +++ /dev/null @@ -1,240 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import omegaconf -import torch -import torchvision.transforms.functional as transforms_F - -from cosmos_transfer1.diffusion.datasets.augmentors.control_input import Augmentor -from cosmos_transfer1.diffusion.datasets.dataset_utils import obtain_augmentation_size, obtain_image_size - - -class ReflectionPadding(Augmentor): - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Performs reflection padding. This function also returns a padding mask. - - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict where images are center cropped. 
- """ - - assert self.args is not None, "Please specify args in augmentation" - if self.output_keys is None: - self.output_keys = self.input_keys - - # Obtain image and augmentation sizes - orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) - target_size = obtain_augmentation_size(data_dict, self.args) - - assert isinstance(target_size, (tuple, omegaconf.listconfig.ListConfig)), "Please specify target size as tuple" - target_w, target_h = target_size - - target_w = int(target_w) - target_h = int(target_h) - - # Calculate padding vals - padding_left = int((target_w - orig_w) / 2) - padding_right = target_w - orig_w - padding_left - padding_top = int((target_h - orig_h) / 2) - padding_bottom = target_h - orig_h - padding_top - padding_vals = [padding_left, padding_top, padding_right, padding_bottom] - - for inp_key, out_key in zip(self.input_keys, self.output_keys): - if max(padding_vals[0], padding_vals[2]) >= orig_w or max(padding_vals[1], padding_vals[3]) >= orig_h: - # In this case, we can't perform reflection padding. This is because padding values - # are larger than the image size. So, perform edge padding instead. - data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="edge") - else: - # Perform reflection padding - data_dict[out_key] = transforms_F.pad(data_dict[inp_key], padding_vals, padding_mode="reflect") - - if out_key != inp_key: - del data_dict[inp_key] - - # Return padding_mask when padding is performed. - # Padding mask denotes which pixels are padded. - padding_mask = torch.ones((1, target_h, target_w)) - padding_mask[:, padding_top : (padding_top + orig_h), padding_left : (padding_left + orig_w)] = 0 - data_dict["padding_mask"] = padding_mask - data_dict["image_size"] = torch.tensor([target_h, target_w, orig_h, orig_w], dtype=torch.float) - - return data_dict - - -class ResizeSmallestSide(Augmentor): - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Performs resizing to smaller side - - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict where images are resized - """ - - if self.output_keys is None: - self.output_keys = self.input_keys - assert self.args is not None, "Please specify args in augmentations" - - for inp_key, out_key in zip(self.input_keys, self.output_keys): - out_size = obtain_augmentation_size(data_dict, self.args) - assert isinstance(out_size, int), "Arg size in resize should be an integer" - data_dict[out_key] = transforms_F.resize( - data_dict[inp_key], - size=out_size, # type: ignore - interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), - antialias=True, - ) - if out_key != inp_key: - del data_dict[inp_key] - return data_dict - - -class ResizeLargestSide(Augmentor): - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Performs resizing to larger side - - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict where images are resized - """ - - if self.output_keys is None: - self.output_keys = self.input_keys - assert self.args is not None, "Please specify args in augmentations" - - for inp_key, out_key in zip(self.input_keys, self.output_keys): - out_size = 
obtain_augmentation_size(data_dict, self.args) - assert isinstance(out_size, int), "Arg size in resize should be an integer" - orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) - - scaling_ratio = min(out_size / orig_w, out_size / orig_h) - target_size = [int(scaling_ratio * orig_h), int(scaling_ratio * orig_w)] - - data_dict[out_key] = transforms_F.resize( - data_dict[inp_key], - size=target_size, - interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), - antialias=True, - ) - if out_key != inp_key: - del data_dict[inp_key] - return data_dict - - -class ResizeSmallestSideAspectPreserving(Augmentor): - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Performs aspect-ratio preserving resizing. - Image is resized to the dimension which has the smaller ratio of (size / target_size). - First we compute (w_img / w_target) and (h_img / h_target) and resize the image - to the dimension that has the smaller of these ratios. - - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict where images are resized - """ - - if self.output_keys is None: - self.output_keys = self.input_keys - assert self.args is not None, "Please specify args in augmentations" - - img_size = obtain_augmentation_size(data_dict, self.args) - assert isinstance( - img_size, (tuple, omegaconf.listconfig.ListConfig) - ), f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" - img_w, img_h = img_size - - orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) - scaling_ratio = max((img_w / orig_w), (img_h / orig_h)) - target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) - - assert ( - target_size[0] >= img_h and target_size[1] >= img_w - ), f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" - - for inp_key, out_key in zip(self.input_keys, self.output_keys): - data_dict[out_key] = transforms_F.resize( - data_dict[inp_key], - size=target_size, # type: ignore - interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), - antialias=True, - ) - - if out_key != inp_key: - del data_dict[inp_key] - return data_dict - - -class ResizeLargestSideAspectPreserving(Augmentor): - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Performs aspect-ratio preserving resizing. - Image is resized to the dimension which has the larger ratio of (size / target_size). - First we compute (w_img / w_target) and (h_img / h_target) and resize the image - to the dimension that has the larger of these ratios. 
- - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict where images are resized - """ - - if self.output_keys is None: - self.output_keys = self.input_keys - assert self.args is not None, "Please specify args in augmentations" - - img_size = obtain_augmentation_size(data_dict, self.args) - assert isinstance( - img_size, (tuple, omegaconf.listconfig.ListConfig) - ), f"Arg size in resize should be a tuple, get {type(img_size)}, {img_size}" - img_w, img_h = img_size - - orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) - scaling_ratio = min((img_w / orig_w), (img_h / orig_h)) - target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) - - assert ( - target_size[0] <= img_h and target_size[1] <= img_w - ), f"Resize error. orig {(orig_w, orig_h)} desire {img_size} compute {target_size}" - - for inp_key, out_key in zip(self.input_keys, self.output_keys): - data_dict[out_key] = transforms_F.resize( - data_dict[inp_key], - size=target_size, # type: ignore - interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), - antialias=True, - ) - - if out_key != inp_key: - del data_dict[inp_key] - return data_dict diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py b/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py deleted file mode 100644 index cfd8f2765f18260ad2c4e4e51025e70d41aa25ff..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/control_input.py +++ /dev/null @@ -1,1469 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
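# [Editor's note, not part of the deleted control_input.py below] A worked
# example of the ResizeLargestSideAspectPreserving + ReflectionPadding pair
# defined above, for a 1920x1080 clip mapped into (img_w, img_h) = (1280, 704),
# the 16:9 entry of the 720p bucket in VIDEO_RES_SIZE_INFO defined just below.
img_w, img_h = 1280, 704            # target bucket (width, height)
orig_w, orig_h = 1920, 1080         # incoming frame size
scaling_ratio = min(img_w / orig_w, img_h / orig_h)      # ~0.6519
target_h = int(scaling_ratio * orig_h + 0.5)             # 704
target_w = int(scaling_ratio * orig_w + 0.5)             # 1252
padding_left = int((img_w - target_w) / 2)               # 14
padding_right = img_w - target_w - padding_left          # 14
print(target_h, target_w, padding_left, padding_right)   # 704 1252 14 14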
- -import random -from functools import partial -from typing import Any, Optional - -import cv2 -import matplotlib.colors as mcolors -import numpy as np -import pycocotools.mask -import torch -import torchvision.transforms.functional as transforms_F - -from cosmos_transfer1.diffusion.config.transfer.blurs import ( - AnisotropicDiffusionConfig, - BilateralFilterConfig, - BlurAugmentorConfig, - GaussianBlurConfig, - GuidedFilterConfig, - LaplacianOfGaussianConfig, - MedianBlurConfig, -) -from cosmos_transfer1.diffusion.datasets.augmentors.guided_filter import FastGuidedFilter -from cosmos_transfer1.diffusion.datasets.augmentors.human_keypoint_utils import ( - coco_wholebody_133_skeleton, - convert_coco_to_openpose, - openpose134_skeleton, -) -from cosmos_transfer1.utils import log - -IMAGE_RES_SIZE_INFO: dict[str, tuple[int, int]] = { - "1080": { # the image format does not support 1080, but here we match it with video resolution - "1,1": (1024, 1024), - "4,3": (1440, 1056), - "3,4": (1056, 1440), - "16,9": (1920, 1056), - "9,16": (1056, 1920), - }, - "1024": {"1,1": (1024, 1024), "4,3": (1280, 1024), "3,4": (1024, 1280), "16,9": (1280, 768), "9,16": (768, 1280)}, - # 720; mainly for make sure it matches video resolution conventions - "720": {"1,1": (960, 960), "4,3": (960, 704), "3,4": (704, 960), "16,9": (1280, 704), "9,16": (704, 1280)}, - "512": {"1,1": (512, 512), "4,3": (640, 512), "3,4": (512, 640), "16,9": (640, 384), "9,16": (384, 640)}, - "256": { - "1,1": (256, 256), - "4,3": (320, 256), - "3,4": (256, 320), - "16,9": (320, 192), - "9,16": (192, 320), - }, -} - - -VIDEO_RES_SIZE_INFO: dict[str, tuple[int, int]] = { - "1080": { # 1080p doesn't have 1:1 - "1,1": (1024, 1024), - "4,3": (1440, 1056), - "3,4": (1056, 1440), - "16,9": (1920, 1056), - "9,16": (1056, 1920), - }, - "720": {"1,1": (960, 960), "4,3": (960, 704), "3,4": (704, 960), "16,9": (1280, 704), "9,16": (704, 1280)}, - "512": {"1,1": (512, 512), "4,3": (640, 512), "3,4": (512, 640), "16,9": (640, 384), "9,16": (384, 640)}, - "480": {"1,1": (480, 480), "4,3": (640, 480), "3,4": (480, 640), "16,9": (768, 432), "9,16": (432, 768)}, - "256": { - "1,1": (256, 256), - "4,3": (320, 256), - "3,4": (256, 320), - "16,9": (320, 192), - "9,16": (192, 320), - }, -} - - -class Augmentor: - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - r"""Base augmentor class - - Args: - input_keys (list): List of input keys - output_keys (list): List of output keys - args (dict): Arguments associated with the augmentation - """ - self.input_keys = input_keys - self.output_keys = output_keys - self.args = args - - def __call__(self, *args: Any, **kwds: Any) -> Any: - raise ValueError("Augmentor not implemented") - - -def resize_frames(frames, is_image, data_dict): - # Resize the frames to target size before computing control signals to save compute. 
- need_reshape = len(frames.shape) < 4 - if need_reshape: # HWC -> CTHW - frames = frames.transpose((2, 0, 1))[:, None] - H, W = frames.shape[2], frames.shape[3] - - if "__url__" in data_dict and "aspect_ratio" in data_dict["__url__"].meta.opts: - aspect_ratio = data_dict["__url__"].meta.opts["aspect_ratio"] - elif "aspect_ratio" in data_dict: # Non-webdataset format - aspect_ratio = data_dict["aspect_ratio"] - else: - aspect_ratio = "16,9" - RES_SIZE_INFO = IMAGE_RES_SIZE_INFO if is_image else VIDEO_RES_SIZE_INFO - new_W, new_H = RES_SIZE_INFO["720"][aspect_ratio] - scaling_ratio = min((new_W / W), (new_H / H)) - if scaling_ratio < 1: - W, H = int(scaling_ratio * W + 0.5), int(scaling_ratio * H + 0.5) - frames = [ - cv2.resize(_image_np, (W, H), interpolation=cv2.INTER_AREA) for _image_np in frames.transpose((1, 2, 3, 0)) - ] - frames = np.stack(frames).transpose((3, 0, 1, 2)) - if need_reshape: # CTHW -> HWC - frames = frames[:, 0].transpose((1, 2, 0)) - return frames - - -# frames CTHW -def apply_gaussian_blur(frames: np.ndarray, ksize: int = 5, sigmaX: float = 1.0) -> np.ndarray: - if ksize % 2 == 0: - ksize += 1 # ksize must be odd - blurred_image = [ - cv2.GaussianBlur(_image_np, (ksize, ksize), sigmaX=sigmaX) for _image_np in frames.transpose((1, 2, 3, 0)) - ] - blurred_image = np.stack(blurred_image).transpose((3, 0, 1, 2)) - return blurred_image - - -class GaussianBlur: - def __init__(self, config: GaussianBlurConfig) -> None: - self.use_random = config.use_random - self.config = config - - def __call__(self, frames: np.ndarray) -> np.ndarray: - if self.use_random: - ksize = np.random.randint(self.config.ksize_min, self.config.ksize_max + 1) - sigmaX = np.random.uniform(self.config.sigmaX_min, self.config.sigmaX_max) - else: - ksize = self.config.ksize - sigmaX = self.config.sigmaX - return apply_gaussian_blur(frames, ksize, sigmaX) - - -def apply_guided_filter(frames: np.ndarray, radius: int, eps: float, scale: float) -> np.ndarray: - blurred_image = [ - FastGuidedFilter(_image_np, radius, eps, scale).filter(_image_np) - for _image_np in frames.transpose((1, 2, 3, 0)) - ] - blurred_image = np.stack(blurred_image).transpose((3, 0, 1, 2)) - return blurred_image - - -class GuidedFilter: - def __init__(self, config: GuidedFilterConfig) -> None: - self.use_random = config.use_random - self.config = config - - def __call__(self, frames: np.ndarray) -> np.ndarray: - if self.use_random: - radius = np.random.randint(self.config.radius_min, self.config.radius_max + 1) - eps = np.random.uniform(self.config.eps_min, self.config.eps_max) - scale = np.random.randint(self.config.scale_min, self.config.scale_max + 1) - else: - radius = self.config.radius - eps = self.config.eps - scale = self.config.scale - return apply_guided_filter(frames, radius, eps, scale) - - -def apply_bilateral_filter( - frames: np.ndarray, - d: int = 9, - sigma_color: float = 75, - sigma_space: float = 75, - iter: int = 1, -) -> np.ndarray: - blurred_image = [] - for _image_np in frames.transpose((1, 2, 3, 0)): - for _ in range(iter): - _image_np = cv2.bilateralFilter(_image_np, d, sigma_color, sigma_space) - blurred_image += [_image_np] - - blurred_image = np.stack(blurred_image).transpose((3, 0, 1, 2)) - return blurred_image - - -class BilateralFilter: - def __init__(self, config: BilateralFilterConfig) -> None: - self.use_random = config.use_random - self.config = config - - def __call__(self, frames: np.ndarray) -> np.ndarray: - config = self.config - if self.use_random: - d = np.random.randint(config.d_min, 
config.d_max) - sigma_color = np.random.randint(config.sigma_color_min, config.sigma_color_max) - sigma_space = np.random.randint(config.sigma_space_min, config.sigma_space_max) - iter = np.random.randint(config.iter_min, config.iter_max) - else: - d = config.d - sigma_color = config.sigma_color - sigma_space = config.sigma_space - iter = config.iter - return apply_bilateral_filter(frames, d, sigma_color, sigma_space, iter) - - -def apply_median_blur(frames: np.ndarray, ksize=5) -> np.ndarray: - if ksize % 2 == 0: - ksize += 1 # ksize must be odd - blurred_image = [cv2.medianBlur(_image_np, ksize) for _image_np in frames.transpose((1, 2, 3, 0))] - blurred_image = np.stack(blurred_image).transpose((3, 0, 1, 2)) - return blurred_image - - -class MedianBlur: - def __init__(self, config: MedianBlurConfig) -> None: - self.use_random = config.use_random - self.config = config - - def __call__(self, frames: np.ndarray) -> np.ndarray: - if self.use_random: - ksize = np.random.randint(self.config.ksize_min, self.config.ksize_max + 1) - else: - ksize = self.config.ksize - return apply_median_blur(frames, ksize) - - -def apply_laplacian_of_gaussian( - frames: np.ndarray, ksize: int = 5, sigma: float = 1.4, binarize: bool = False, threshold: float = 0.0 -) -> np.ndarray: - """ - Apply Laplacian of Gaussian edge detection to a set of frames. - - Args: - frames (np.ndarray): Input frames with shape (C, T, H, W) - ksize (int): Size of the Gaussian kernel. Must be odd and positive. - sigma (float): Standard deviation of the Gaussian distribution. - binarize (bool): Whether to binarize the output edge map. - threshold (float): Threshold for binarization (if binarize is True). - - Returns: - np.ndarray: Edge-detected frames with shape (C, T, H, W). - """ - # Ensure ksize is odd - if ksize % 2 == 0: - ksize += 1 # ksize must be odd - - edge_frames = [] - for frame in frames.transpose((1, 2, 3, 0)): # (T, H, W, C) - # Convert to grayscale if the image is in color - if frame.shape[-1] == 3: - gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) - else: - gray = frame.squeeze() - - blurred = cv2.GaussianBlur(gray, (ksize, ksize), sigma) - laplacian = cv2.Laplacian(blurred, cv2.CV_64F) - normalized = cv2.normalize(laplacian, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U) - - if binarize: - _, edge_map = cv2.threshold(normalized, threshold, 255, cv2.THRESH_BINARY) - else: - edge_map = normalized - - # Expand dimensions to match input shape - edge_map = np.repeat(edge_map[..., np.newaxis], frames.shape[0], axis=-1) - edge_frames.append(edge_map) - return np.stack(edge_frames).transpose((3, 0, 1, 2)) # (C, T, H, W) - - -class LaplacianOfGaussian: - """ - Applies Laplacian of Gaussian edge detection to images or video frames. 
- """ - - def __init__(self, config: LaplacianOfGaussianConfig) -> None: - self.use_random = config.use_random - self.config = config - - def __call__(self, frames: np.ndarray) -> np.ndarray: - """Apply LoG to input frames.""" - if self.use_random: - ksize = np.random.randint(self.config.ksize_min, self.config.ksize_max + 1) - sigma = np.random.uniform(self.config.sigma_min, self.config.sigma_max) - binarize = np.random.choice([True, False]) if not self.config.binarize else self.config.binarize - threshold = np.random.uniform(self.config.threshold_min, self.config.threshold_max) if binarize else 0 - else: - ksize = self.config.ksize - sigma = self.config.sigma - binarize = self.config.binarize - threshold = self.config.threshold - return apply_laplacian_of_gaussian(frames, ksize, sigma, binarize, threshold) - - -def apply_anisotropic_diffusion(frames: np.ndarray, alpha: float, K: float, niters: int) -> np.ndarray: - """ - Apply Anisotropic Diffusion to a set of frames. - - Args: - frames (np.ndarray): Input frames with shape (C, T, H, W) - alpha (float): The amount of time to step forward on each iteration (between 0 and 1) - K (float): Sensitivity to edges - niters (int): Number of iterations - - Returns: - np.ndarray: Anisotropic-diffused frames with shape (C, T, H, W). - """ - blurred_image = [ - cv2.ximgproc.anisotropicDiffusion(_image_np, alpha, K, niters) for _image_np in frames.transpose((1, 2, 3, 0)) - ] - blurred_image = np.stack(blurred_image).transpose((3, 0, 1, 2)) - - return blurred_image - - -class AnisotropicDiffusion: - """ - Applies Anisotropic Diffusion to images or video frames. - """ - - def __init__(self, config: AnisotropicDiffusionConfig) -> None: - self.use_random = config.use_random - self.config = config - - def __call__(self, frames: np.ndarray) -> np.ndarray: - if self.use_random: - alpha = np.random.uniform(self.config.alpha_min, self.config.alpha_max) - K = np.random.uniform(self.config.K_min, self.config.K_max) - niters = np.random.randint(self.config.niters_min, self.config.niters_max + 1) - else: - alpha = self.config.alpha - K = self.config.K - niters = self.config.niters - return apply_anisotropic_diffusion(frames, alpha, K, niters) - - -class Blur: - def __init__(self, config: BlurAugmentorConfig, output_key: str = "") -> None: - self.output_key = output_key if output_key else None - - probabilities = [combo.probability for combo in config.blur_combinations] - total_prob = sum(probabilities) - assert abs(total_prob - 1.0) < 1e-6, f"Probabilities must sum to 1.0, got {total_prob}" - - self.blur_combinations = config.blur_combinations - self.downscale_factor = config.downscale_factor - self.probabilities = probabilities - self._set_blur_instances() - - def _set_blur_instances(self): - if not self.blur_combinations: - return - self.blur_combinations_instances = [] - - for blur_combination in self.blur_combinations: - blur_mapping = { - "gaussian": (GaussianBlur, blur_combination.gaussian_blur), - "guided": (GuidedFilter, blur_combination.guided_filter), - "bilateral": (BilateralFilter, blur_combination.bilateral_filter), - "median": (MedianBlur, blur_combination.median_blur), - "log": (LaplacianOfGaussian, blur_combination.log), - "anisotropic": (AnisotropicDiffusion, blur_combination.anisotropic_diffusion), - } - - cur_instances = [] - for blur_type in blur_combination.blur_types: - assert blur_type in blur_mapping, f"Unknown {blur_type}. Needs to correct blur_type or blur_mapping." 
- - blur_class, blur_config = blur_mapping[blur_type] - cur_instances.append(blur_class(blur_config)) - - self.blur_combinations_instances.append(cur_instances) - - assert len(self.blur_combinations_instances) == len( - self.blur_combinations - ), "Number of blur_combinations_instances needs to match number of blur_combinations." - - def __call__(self, frames: np.ndarray) -> np.ndarray: - blur_instances = random.choices(self.blur_combinations_instances, weights=self.probabilities, k=1)[0] - - H, W = frames.shape[2], frames.shape[3] - downscale_factor = random.choice(self.downscale_factor) - if downscale_factor > 1: - frames = [ - cv2.resize(_image_np, (W // downscale_factor, H // downscale_factor), interpolation=cv2.INTER_AREA) - for _image_np in frames.transpose((1, 2, 3, 0)) - ] - frames = np.stack(frames).transpose((3, 0, 1, 2)) - - for ins in blur_instances: - frames = ins(frames) - - if downscale_factor > 1: - frames = [ - cv2.resize(_image_np, (W, H), interpolation=cv2.INTER_LINEAR) - for _image_np in frames.transpose((1, 2, 3, 0)) - ] - frames = np.stack(frames).transpose((3, 0, 1, 2)) - return frames - - -class AddControlInputBlurDownUp(Augmentor): - """ - Main class for adding blurred input to the data dictionary. - self.output_keys[0] indicates the types of blur added to the input. - For example, control_input_gaussian_guided indicates that both Gaussian and Guided filters are applied - """ - - def __init__( - self, - input_keys: list, # [key_load, key_img] - output_keys: Optional[list] = [ - "control_input_gaussian_guided_bilateral_median_log" - ], # eg ["control_input_gaussian_guided"] - args: Optional[dict] = None, # not used - use_random: bool = True, # whether to use random parameters - blur_config: BlurAugmentorConfig = BlurAugmentorConfig(), - downup_preset: str | int = "medium", # preset strength for downup factor - min_downup_factor: int = 4, # minimum downup factor - max_downup_factor: int = 16, # maximum downup factor - downsize_before_blur: bool = False, # whether to downsize before applying blur and then upsize or downup after blur - ) -> None: - super().__init__(input_keys, output_keys, args) - self.use_random = use_random - downup_preset_values = { - "none": 1, - "very_low": min_downup_factor, - "low": min_downup_factor, - "medium": (min_downup_factor + max_downup_factor) // 2, - "high": max_downup_factor, - "very_high": max_downup_factor, - } - - self.blur = Blur(config=blur_config, output_key=self.output_keys[0]) - - self.downup_preset = downup_preset if isinstance(downup_preset, int) else downup_preset_values[downup_preset] - self.downsize_before_blur = downsize_before_blur - self.min_downup_factor = min_downup_factor - self.max_downup_factor = max_downup_factor - - def _load_frame(self, data_dict: dict) -> tuple[np.ndarray, bool]: - key_img = self.input_keys[1] - frames = data_dict[key_img] - frames = np.array(frames) - is_image = False - if len(frames.shape) < 4: - frames = frames.transpose((2, 0, 1))[:, None] - is_image = True - return frames, is_image - - def __call__(self, data_dict: dict) -> dict: - if "control_input_vis" in data_dict: - # already processed - return data_dict - key_img = self.input_keys[1] - key_out = self.output_keys[0] - frames, is_image = self._load_frame(data_dict) - - # Resize the frames to target size before blurring. 
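The Blur wrapper above samples one operator chain per call, weighted by the per-combination probabilities (which, as asserted in Blur.__init__, must sum to 1.0). A toy, self-contained illustration of just that selection step, with hypothetical combination names standing in for BlurAugmentorConfig.blur_combinations:

import random

# Hypothetical combinations and weights; the real values come from the config.
combinations = [["gaussian"], ["gaussian", "guided"], ["bilateral", "median", "log"]]
probabilities = [0.5, 0.3, 0.2]   # must sum to 1.0

chosen = random.choices(combinations, weights=probabilities, k=1)[0]
print("applying blur chain:", " -> ".join(chosen))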
- frames = resize_frames(frames, is_image, data_dict) - H, W = frames.shape[2], frames.shape[3] - - if self.use_random: - scale_factor = random.randint(self.min_downup_factor, self.max_downup_factor + 1) - else: - scale_factor = self.downup_preset - if self.downsize_before_blur: - frames = [ - cv2.resize(_image_np, (W // scale_factor, H // scale_factor), interpolation=cv2.INTER_AREA) - for _image_np in frames.transpose((1, 2, 3, 0)) - ] - frames = np.stack(frames).transpose((3, 0, 1, 2)) - frames = self.blur(frames) - if self.downsize_before_blur: - frames = [ - cv2.resize(_image_np, (W, H), interpolation=cv2.INTER_LINEAR) - for _image_np in frames.transpose((1, 2, 3, 0)) - ] - frames = np.stack(frames).transpose((3, 0, 1, 2)) - if is_image: - frames = frames[:, 0] - # turn into tensor - controlnet_img = torch.from_numpy(frames) - if not self.downsize_before_blur: - # Resize image - controlnet_img = transforms_F.resize( - controlnet_img, - size=(int(H / scale_factor), int(W / scale_factor)), - interpolation=transforms_F.InterpolationMode.BICUBIC, - antialias=True, - ) - controlnet_img = transforms_F.resize( - controlnet_img, - size=(H, W), - interpolation=transforms_F.InterpolationMode.BICUBIC, - antialias=True, - ) - data_dict[key_out] = controlnet_img - return data_dict - - -class AddControlInputEdge(Augmentor): - """ - Add control input to the data dictionary. control input are expanded to 3-channels - steps to add new items: modify this file, configs/conditioner.py, conditioner.py - """ - - def __init__( - self, - input_keys: list, - output_keys: Optional[list] = ["control_input_edge"], - args: Optional[dict] = None, - use_random: bool = True, - preset_canny_threshold="medium", - **kwargs, - ) -> None: - super().__init__(input_keys, output_keys, args) - self.use_random = use_random - self.preset_strength = preset_canny_threshold - - def __call__(self, data_dict: dict) -> dict: - if "control_input_edge" in data_dict: - # already processed - return data_dict - key_img = self.input_keys[1] - key_out = self.output_keys[0] - frames = data_dict[key_img] - # Get lower and upper threshold for canny edge detection. - if self.use_random: # always on for training, always off for inference - t_lower = np.random.randint(20, 100) # Get a random lower thre within [0, 255] - t_diff = np.random.randint(50, 150) # Get a random diff between lower and upper - t_upper = min(255, t_lower + t_diff) # The upper thre is lower added by the diff - else: - if self.preset_strength == "none" or self.preset_strength == "very_low": - t_lower, t_upper = 20, 50 - elif self.preset_strength == "low": - t_lower, t_upper = 50, 100 - elif self.preset_strength == "medium": - t_lower, t_upper = 100, 200 - elif self.preset_strength == "high": - t_lower, t_upper = 200, 300 - elif self.preset_strength == "very_high": - t_lower, t_upper = 300, 400 - else: - raise ValueError(f"Preset {self.preset_strength} not recognized.") - frames = np.array(frames) - is_image = len(frames.shape) < 4 - - # Resize the frames to target size before computing canny edges. - frames = resize_frames(frames, is_image, data_dict) - - # Compute the canny edge map by the two thresholds. 
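For reference, a self-contained sketch of the edge map computed by the block below, using the "medium" preset thresholds (100, 200) on a synthetic frame; cv2.Canny keeps pixels above the upper threshold as strong edges and retains weaker ones only when they connect to a strong edge.

import cv2
import numpy as np

frame = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # (H, W, C) uint8
t_lower, t_upper = 100, 200            # "medium" preset from the branch above
edge_map = cv2.Canny(frame, t_lower, t_upper)
print(edge_map.shape, edge_map.dtype)  # (480, 640) uint8, values in {0, 255}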
- if is_image: - edge_maps = cv2.Canny(frames, t_lower, t_upper)[None, None] - else: - edge_maps = [cv2.Canny(img, t_lower, t_upper) for img in frames.transpose((1, 2, 3, 0))] - edge_maps = np.stack(edge_maps)[None] - edge_maps = torch.from_numpy(edge_maps).expand(3, -1, -1, -1) - if is_image: - edge_maps = edge_maps[:, 0] - data_dict[key_out] = edge_maps - return data_dict - - -class AddControlInput(Augmentor): - """ - For backward compatibility. The previously trained models use legacy_process - """ - - def __init__( - self, - input_keys: list, - output_keys=["control_input_vis"], - args=None, - blur_config: BlurAugmentorConfig = BlurAugmentorConfig(), - use_random=True, - preset_blur_strength="medium", - **kwargs, - ) -> None: - super().__init__(input_keys, output_keys, args) - - self.process = AddControlInputBlurDownUp( - input_keys, - output_keys, - args, - blur_config=blur_config, - downup_preset=preset_blur_strength, # preset strength for downup factor - use_random=use_random, - ) - - def __call__(self, data_dict: dict) -> dict: - return self.process(data_dict) - - -class AddControlInputComb(Augmentor): - """ - Add control input to the data dictionary. control input are expanded to 3-channels - steps to add new items: modify this file, configs/conditioner.py, conditioner.py - """ - - def __init__( - self, - input_keys: list, - output_keys: Optional[list] = None, - blur_config: BlurAugmentorConfig = None, - args: Optional[dict] = None, - **kwargs, - ) -> None: - super().__init__(input_keys, output_keys, args) - assert "comb" in args - self.comb = {} - for class_name in args["comb"]: - if class_name in [AddControlInput, AddControlInputBlurDownUp]: - aug = class_name(input_keys=input_keys, args=args, blur_config=blur_config, **kwargs) - else: - aug = class_name(input_keys=input_keys, args=args, **kwargs) - - key = aug.output_keys[0] - self.comb[key] = aug - - def __call__(self, data_dict: dict) -> dict: - all_comb = [] - for k, v in self.comb.items(): - data_dict = v(data_dict) - all_comb.append(data_dict.pop(k)) - if all_comb[-1].dim() == 4: - all_comb[-1] = all_comb[-1].squeeze(1) - all_comb = torch.cat(all_comb, dim=0) - data_dict[self.output_keys[0]] = all_comb - return data_dict - - -def get_augmentor_for_eval( - input_key: str, - output_key: str, - blur_config: BlurAugmentorConfig = BlurAugmentorConfig(), - preset_blur_strength: str = "medium", - preset_canny_threshold: str = "medium", - blur_type: str = "gaussian,guided,bilateral,median,log,anisotropic", # do we still need this value? 
-) -> AddControlInputComb: - comb = [] - output_keys = output_key.replace("control_input_", "").split("_") - for key in output_keys: - if "edge" in key: - comb.append(partial(AddControlInputEdge, output_keys=["control_input_edge"])) - elif "upscale" in key: - comb.append(partial(AddControlInputUpscale, output_keys=["control_input_upscale"])) - elif "depth" in key: - comb.append(partial(AddControlInputDepth, output_keys=["control_input_depth"])) - elif "seg" in key: - comb.append(partial(AddControlInputSeg, output_keys=["control_input_seg"])) - elif "vis" in key: - comb.append(AddControlInput) - elif "keypoint" in key: - comb.append(partial(AddControlInputKeypoint, output_keys=["control_input_keypoint"])) - elif "hdmap" in key: - comb.append(partial(AddControlInputHDMAP, output_keys=["control_input_hdmap"])) - elif "lidar" in key: - comb.append(partial(AddControlInputLIDAR, output_keys=["control_input_lidar"])) - process = AddControlInputComb( - input_keys=["", input_key], - output_keys=[output_key], - args={"comb": comb}, - blur_config=blur_config, - use_random=False, - preset_blur_strength=preset_blur_strength, - preset_canny_threshold=preset_canny_threshold, - ) - return process - - -class AddControlInputDepth(Augmentor): - """ - Add control input to the data dictionary. control input are expanded to 3-channels - steps to add new items: modify this file, configs/conditioner.py, conditioner.py - """ - - def __init__( - self, - input_keys: list, - output_keys: Optional[list] = ["control_input_depth"], - args: Optional[dict] = None, - **kwargs, - ) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - if "control_input_depth" in data_dict: - # already processed - return data_dict - if "video" not in data_dict: - image = np.array(data_dict[self.input_keys[1]]) - h, w, _ = image.shape - data_dict[self.output_keys[0]] = torch.from_numpy(np.zeros((3, h, w)).astype(np.uint8)) - return data_dict - - # assert data_dict["chunk_index"] == data_dict["depth"]["chunk_index"] - key_out = self.output_keys[0] - depth = data_dict["depth"]["video"] - data_dict[key_out] = depth - return data_dict - - -class AddControlInputHDMAP(Augmentor): - """ - Add control input to the data dictionary. control input are expanded to 3-channels - steps to add new items: modify this file, configs/conditioner.py, conditioner.py. - For HDMAP condition, the hdmap should already been provided as input. So no extra process is needed. - """ - - def __init__( - self, - input_keys: list, - output_keys: Optional[list] = ["control_input_hdmap"], - args: Optional[dict] = None, - **kwargs, - ) -> None: - self.output_keys = output_keys - self.input_keys = input_keys - - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict) -> dict: - if "control_input_hdmap" in data_dict: - # already processed - return data_dict - - key_out = self.output_keys[0] - hdmap = data_dict["hdmap"]["video"] - data_dict[key_out] = hdmap - return data_dict - - -class AddControlInputLIDAR(Augmentor): - """ - Add control input to the data dictionary. control input are expanded to 3-channels - steps to add new items: modify this file, configs/conditioner.py, conditioner.py. - For LIDAR condition, the lidar should already been provided as input. So no extra process is needed. 
- """ - - def __init__( - self, - input_keys: list, - output_keys: Optional[list] = ["control_input_lidar"], - args: Optional[dict] = None, - **kwargs, - ) -> None: - self.output_keys = output_keys - self.input_keys = input_keys - - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict) -> dict: - if "control_input_lidar" in data_dict: - # already processed - return data_dict - - key_out = self.output_keys[0] - lidar = data_dict["lidar"]["video"] - data_dict[key_out] = lidar - return data_dict - - -# Array of 23 highly distinguishable colors in RGB format -PREDEFINED_COLORS_SEGMENTATION = np.array( - [ - [255, 0, 0], # Red - [0, 255, 0], # Green - [0, 0, 255], # Blue - [255, 255, 0], # Yellow - [0, 255, 255], # Cyan - [255, 0, 255], # Magenta - [255, 140, 0], # Dark Orange - [255, 105, 180], # Hot Pink - [0, 0, 139], # Dark Blue - [0, 128, 128], # Teal - [75, 0, 130], # Indigo - [128, 0, 128], # Purple - [255, 69, 0], # Red-Orange - [34, 139, 34], # Forest Green - [128, 128, 0], # Olive - [70, 130, 180], # Steel Blue - [255, 215, 0], # Gold - [255, 222, 173], # Navajo White - [144, 238, 144], # Light Green - [255, 99, 71], # Tomato - [221, 160, 221], # Plum - [0, 255, 127], # Spring Green - [255, 255, 255], # White - ] -) - - -def generate_distinct_colors(): - """ - Generate `n` visually distinguishable and randomized colors. - - Returns: - np.ndarray, (3) - """ - # Randomize hue, saturation, and lightness within a range - hue = random.uniform(0, 1) # Full spectrum of hues - saturation = random.uniform(0.1, 1) # Vibrant colors - lightness = random.uniform(0.2, 1.0) # Avoid too dark - - r, g, b = mcolors.hsv_to_rgb((hue, saturation, lightness)) - return (np.array([r, g, b]) * 255).astype(np.uint8) - - -def segmentation_color_mask(segmentation_mask: np.ndarray, use_fixed_color_list: bool = False) -> np.ndarray: - """ - Convert segmentation mask to color mask - Args: - segmentation_mask: np.ndarray, shape (num_masks, T, H, W) - Returns: - np.ndarray, shape (3, T, H, W), with each mask converted to a color mask, value [0,255] - """ - - num_masks, T, H, W = segmentation_mask.shape - segmentation_mask_sorted = [segmentation_mask[i] for i in range(num_masks)] - # Sort the segmentation mask by the number of non-zero pixels, from most to least - segmentation_mask_sorted = sorted(segmentation_mask_sorted, key=lambda x: np.count_nonzero(x), reverse=True) - - output = np.zeros((3, T, H, W), dtype=np.uint8) - if use_fixed_color_list: - predefined_colors_permuted = PREDEFINED_COLORS_SEGMENTATION[ - np.random.permutation(len(PREDEFINED_COLORS_SEGMENTATION)) - ] - else: - predefined_colors_permuted = [generate_distinct_colors() for _ in range(num_masks)] - # index the segmentation mask from last channel to first channel, i start from num_masks-1 to 0 - for i in range(num_masks): - mask = segmentation_mask_sorted[i] - color = predefined_colors_permuted[i % len(predefined_colors_permuted)] - - # Create boolean mask and use it for assignment - bool_mask = mask > 0 - for c in range(3): - output[c][bool_mask] = color[c] - - return output - - -def decode_partial_rle_width1(rle_obj, start_row, end_row): - """ - Decode a partial RLE encoded mask with width = 1. In SAM2 output, the video mask (num_frame, height, width) are reshaped to (total_size, 1). - Sometimes the video mask could be large, e.g. 1001x1080x1092 shape and it takes >1GB memory if using pycocotools, resulting in segmentation faults when training with multiple GPUs and data workers. 
- This function is used to decode the mask for a subset of frames to reduce memory usage. - - Args: - rle_obj (dict): RLE object containing: - - 'size': A list [height, width=1] indicating the dimensions of the mask. - - 'counts': A bytes or string object containing the RLE encoded data. - start_row (int): The starting row (inclusive). It's computed from frame_start * height * width. - end_row (int): The ending row (exclusive). It's computed from frame_end * height * width. - - Returns: - numpy.ndarray: Decoded binary mask for the specified rows as a 1D numpy array. - """ - height, width = rle_obj["size"] - - # Validate row range - if width != 1: - raise ValueError("This function is optimized for width=1.") - if start_row < 0 or end_row > height or start_row >= end_row: - raise ValueError("Invalid row range specified.") - - # Decode the RLE counts - counts = rle_obj["counts"] - if isinstance(counts, str): - counts = np.frombuffer(counts.encode("ascii"), dtype=np.uint8) - elif isinstance(counts, bytes): - counts = np.frombuffer(counts, dtype=np.uint8) - else: - raise ValueError("Unsupported format for counts. Must be str or bytes.") - - # Interpret counts as a sequence of run lengths - run_lengths = [] - current_val = 0 - i = 0 - while i < len(counts): - x = 0 - k = 0 - more = True - while more: - c = counts[i] - 48 - x |= (c & 0x1F) << (5 * k) - more = (c & 0x20) != 0 - i += 1 - k += 1 - if not more and (c & 0x10): - x |= -1 << (5 * k) - if len(run_lengths) > 2: - x += run_lengths[-2] - - run_lengths.append(x) - current_val += x - if current_val > end_row: - break - # Initialize the partial mask - idx_start = start_row - idx_end = end_row - partial_mask = np.zeros(idx_end - idx_start, dtype=np.uint8) - partial_height = end_row - start_row - idx = 0 # Current global index - for i, run in enumerate(run_lengths): - run_start = idx - run_end = idx + run - if run_end <= idx_start: - # Skip runs entirely before the region - idx = run_end - continue - if run_start >= idx_end: - # Stop decoding once we pass the region - break - - # Calculate overlap with the target region - start = max(run_start, idx_start) - end = min(run_end, idx_end) - if start < end: - partial_start = start - idx_start - partial_end = end - idx_start - partial_mask[partial_start:partial_end] = i % 2 - - idx = run_end - return partial_mask.reshape((partial_height, 1), order="F") - - -class AddControlInputSeg(Augmentor): - """ - Add control input to the data dictionary. control input are expanded to 3-channels - steps to add new items: modify this file, configs/conditioner.py, conditioner.py - """ - - def __init__( - self, - input_keys: list, - output_keys: Optional[list] = ["control_input_seg"], - thres_mb_python_decode: Optional[int] = 256, # required: <= 512 for 7b - use_fixed_color_list: bool = False, - num_masks_max: int = 100, - random_sample_num_masks: bool = True, - args: Optional[dict] = None, - **kwargs, - ) -> None: - """ - Args: - thres_mb_python_decode: int, threshold of memory usage for python decode, in MB - use_fixed_color_list: bool, if True, use predefined colors for segmentation masks. If False, generate random colors for segmentation masks. - num_masks_max: int, maximum number of masks to sample - random_sample_num_masks: bool, if True, sample number of masks randomly. If False, sample all masks in the data. 
- - """ - super().__init__(input_keys, output_keys, args) - self.use_fixed_color_list = use_fixed_color_list - self.num_masks_max = num_masks_max - self.thres_mb_python_decode = thres_mb_python_decode - self.random_sample_num_masks = random_sample_num_masks - - def __call__(self, data_dict: dict) -> dict: - if "control_input_seg" in data_dict: - # already processed - log.info( - f"control_input_seg already processed, shape={data_dict['control_input_seg'].shape}, dtype={data_dict['control_input_seg'].dtype}, value range: {data_dict['control_input_seg'].min()}, {data_dict['control_input_seg'].max()}" - ) - return data_dict - if "video" not in data_dict: - image = np.array(data_dict[self.input_keys[1]]) - h, w, _ = image.shape - data_dict[self.output_keys[0]] = torch.from_numpy(np.zeros((3, h, w)).astype(np.uint8)) - return data_dict - frames = data_dict["video"] - _, T, H, W = frames.shape - - all_masks = [] - # sample number of masks - if self.random_sample_num_masks: - num_masks = np.random.randint(0, min(self.num_masks_max + 1, len(data_dict["segmentation"]) + 1)) - else: - num_masks = len(data_dict["segmentation"]) - mask_ids = np.arange(len(data_dict["segmentation"])).tolist() - mask_ids_select = np.random.choice(mask_ids, num_masks, replace=False) - # concat phrases - segmentation_phrase_all = [data_dict["segmentation"][mid]["phrase"] for mid in mask_ids_select] - segmentation_phrase_all = ";".join(segmentation_phrase_all) - data_dict["segmentation_phrase_all"] = segmentation_phrase_all - # obtrain frame indices - frame_start = data_dict["frame_start"] - frame_end = data_dict["frame_end"] - frame_indices = np.arange(frame_start, frame_end).tolist() - assert ( - len(frame_indices) == T - ), f"frame_indices length {len(frame_indices)} != T {T}, likely due to video decoder using different fps, i.e. sample with stride. Need to return frame indices from video decoder." - all_masks = np.zeros((num_masks, T, H, W)).astype(np.uint8) - for idx, mid in enumerate(mask_ids_select): - mask = data_dict["segmentation"][mid] - shape = mask["segmentation_mask_rle"]["mask_shape"] - num_byte_per_mb = 1024 * 1024 - # total number of elements in uint8 (1 byte) / num_byte_per_mb - if shape[0] * shape[1] * shape[2] / num_byte_per_mb > self.thres_mb_python_decode: - # Switch to python decode if the mask is too large to avoid out of shared memory - - rle = decode_partial_rle_width1( - mask["segmentation_mask_rle"]["data"], - frame_start * shape[1] * shape[2], - frame_end * shape[1] * shape[2], - ) - partial_shape = (frame_end - frame_start, shape[1], shape[2]) - rle = rle.reshape(partial_shape) * 255 - else: - rle = pycocotools.mask.decode(mask["segmentation_mask_rle"]["data"]) - rle = rle.reshape(shape) * 255 - # Select the frames that are in the video - rle = np.stack([rle[i] for i in frame_indices]) - all_masks[idx] = rle - del rle - - key_out = self.output_keys[0] - # both value in [0,255] - # control_input_seg is the colored segmentation mask, value in [0,255], shape (3, T, H, W) - data_dict[key_out] = torch.from_numpy(segmentation_color_mask(all_masks, self.use_fixed_color_list)) - del all_masks # free memory - return data_dict - - -class AddControlInputKeypoint(Augmentor): - """ - Add control input to the data dictionary. 
control input are expanded to 3-channels - steps to add new items: modify this file, configs/conditioner.py, conditioner.py - """ - - def __init__( - self, - input_keys: list, - output_keys: Optional[list] = ["control_input_keypoint"], - args: Optional[dict] = None, - **kwargs, - ) -> None: - super().__init__(input_keys, output_keys, args) - self.control_key_names = ["body-keypoints", "hand-keypoints"] - self.use_openpose_format = args.get("use_openpose_format", True) - self.hand_as_separate_channel = args.get("hand_as_separate_channel", False) - self.kpt_thr = args.get("kpt_thr", 0.6) - self.line_width = args.get("human_kpt_line_width", 4) - - def denormalize_pose_kpts(self, pose_kps: np.ndarray, h: int, w: int): - """ - pose_kps has shape = (#keypoints, 2) - or (#keypoints, 3) where the last dim is the confidence score. - """ - if pose_kps is not None: - assert pose_kps.shape[-1] == 3, "pose_kps must have shape (#keypoints, 3)" - out = pose_kps * np.array([w, h, 1]) - return out - else: - return None - - def draw_skeleton( - self, - img: np.ndarray, - keypoints: np.ndarray, - scores: np.ndarray, - kpt_thr: float = 0.6, - openpose_format: bool = False, - radius: int = 2, - line_width: int = 4, - ): - skeleton_topology = openpose134_skeleton if openpose_format else coco_wholebody_133_skeleton - assert len(keypoints.shape) == 2 - keypoint_info, skeleton_info = ( - skeleton_topology["keypoint_info"], - skeleton_topology["skeleton_info"], - ) - vis_kpt = [s >= kpt_thr for s in scores] - - if self.hand_as_separate_channel: - img_hand = np.zeros_like(img) - hand_idx_start = 92 if openpose_format else 91 # all idx after this are hand keypoints - - link_dict = {} - for i, kpt_info in keypoint_info.items(): - kpt_color = tuple(kpt_info["color"]) - link_dict[kpt_info["name"]] = kpt_info["id"] - - kpt = keypoints[i] - - if vis_kpt[i]: - img = cv2.circle(img, (int(kpt[0]), int(kpt[1])), int(radius), kpt_color, -1) - - if self.hand_as_separate_channel: - if i >= hand_idx_start: - img_hand = cv2.circle(img_hand, (int(kpt[0]), int(kpt[1])), int(radius), kpt_color, -1) - - for i, ske_info in skeleton_info.items(): - link = ske_info["link"] - pt0, pt1 = link_dict[link[0]], link_dict[link[1]] - - if vis_kpt[pt0] and vis_kpt[pt1]: - link_color = ske_info["color"] - kpt0 = keypoints[pt0] - kpt1 = keypoints[pt1] - - img = cv2.line( - img, (int(kpt0[0]), int(kpt0[1])), (int(kpt1[0]), int(kpt1[1])), link_color, thickness=line_width - ) - - if self.hand_as_separate_channel: - if pt0 >= hand_idx_start and pt1 >= hand_idx_start: - img_hand = cv2.line( - img_hand, - (int(kpt0[0]), int(kpt0[1])), - (int(kpt1[0]), int(kpt1[1])), - link_color, - thickness=line_width, - ) - - if self.hand_as_separate_channel: - img = np.concatenate([img, img_hand], axis=-1) # [h,w,6] - return img - - def plot_person_kpts( - self, - person_dict: dict, - pose_vis_img: np.ndarray, - h: int, - w: int, - kpt_thr: float = 0.6, - openpose_format: bool = True, - line_width: int = 4, - ) -> np.ndarray: - """ - plot a single person - in-place update the pose image - """ - try: - body_keypoints = self.denormalize_pose_kpts(person_dict.get("body-keypoints"), h, w) - hand_keypoints = self.denormalize_pose_kpts(person_dict.get("hand-keypoints"), h, w) - except Exception as e: - log.error(f"Error in denormalizing keypoints: {e}") - - assert ( - body_keypoints is not None and hand_keypoints is not None - ), "Both body and hand keypoints must be present." - # all_keypoints: shape=(133, 3). following coco-fullbody skeleton config. 
3 channels are x, y, confidence - all_keypoints = np.vstack([body_keypoints, hand_keypoints]) - kpts, scores = all_keypoints[..., :2], all_keypoints[..., -1] - if openpose_format: - kpts, scores = convert_coco_to_openpose(kpts, scores) - - try: - # [h,w,3] or # [h,w,6] if hand_as_separate_channel - pose_vis_img = self.draw_skeleton( - pose_vis_img, kpts, scores, kpt_thr=kpt_thr, openpose_format=openpose_format, line_width=line_width - ) - except ValueError as e: - log.error(f"Error in draw_skeleton func, {e}") - - return pose_vis_img - - def plot_kpt_video( - self, - kpts_np_dict: dict, - h: int, - w: int, - kpt_thr: float = 0.6, - openpose_format: bool = True, - line_width: int = 4, - ) -> np.ndarray: - """ - Plots a single *frame* for all persons in the frame. - - The raw human keypoint annotation are numpy arrays of pixel coordinates of the joints. - This function plots the keypoints on a black background to form a 3-channel image compatible with controlnet. - - Args: - kpts_np_dict (dict): A dict of keypoint annotations. Each value is a frame's annotation (a list of per-person dict). - H (int): height of the image - W (int): width of the image - openpose_format (bool): whether the convert the coco-wholebody133 keypoints keypoints to openpose format and also - plot in the openpose format (basically add neck keypoint, remove toe keypoints). - Returns: - np.ndarray: keypoints of plotted on black background, shape = (C, T, H, W) C=3, or 6 if hand_as_separate_channel - """ - T = len(kpts_np_dict) - - out = np.empty((3, T, h, w), dtype=np.uint8) # memory save op - - for idx, (t, kpts_np_frame) in enumerate(kpts_np_dict.items()): - pose_vis_img = np.zeros([h, w, 3]) - - # add each person's keypoints to this frame's pose image - for person_dict in kpts_np_frame: - self.plot_person_kpts( - person_dict, - pose_vis_img, - h, - w, - kpt_thr=kpt_thr, - openpose_format=openpose_format, - line_width=line_width, - ) # (h, w, 3) - - out[:, idx, :, :] = pose_vis_img.astype(np.uint8).transpose(2, 0, 1) - - return out - - def get_kpts_from_annotations(self, annotation_dict: dict, total_frames: int, frame_indices: list) -> dict: - """ - For legacy data the annotations are done for chunks of every N frames (N=4). - This function repeats each chunk's first frame annotation to the rest frames - so that they become 'per-frame' and are ControlNet compatible. - - If the data is already per-frame annotated, then no need to call this. - Args: - annotation_dict (dict): Original annotations annotated every chunk_size frames. - Each value is a list of dicts, where each dict has many - human attributes. Here we only keep keypoint-relevant keys. - total_frames (int): Total number of frames in the video. - frame_indices (list[int]): Indices of the video frames sampled from the the original video. - - Returns: - dict: extended annotations for all frames. - """ - annotated_frame_idxs = sorted(list(annotation_dict.keys())) - chunk_size = annotated_frame_idxs[1] - annotated_frame_idxs[0] - assert chunk_size == 1, "Only support videos that have human annotations for every frame" - - # each person's dict can contain many irrelevant annotations (e.g. 
seg masks), here we only keep pose kpts - annotation_dict_simpler = { - key: [{k: v for k, v in sub_dict.items() if k in self.control_key_names} for sub_dict in sub_list] - for key, sub_list in annotation_dict.items() - } - annotation_dict_simpler = {idx: annotation_dict_simpler[idx] for idx in frame_indices} - - return annotation_dict_simpler - - def __call__(self, data_dict: dict) -> dict: - """ - human_annotations: loaded human annotation data pickle. One annotation per N frames. - In the past we did N=4; for ControlNet data annotations we will do N=1. - The pickle is a dict of annotated frames. - The key is the frame number. For each frame, as there can be multiple people, we maintain a list of per-person - dicts. Example: - { - 0: [ - {'body-keypoints': , 'hand-keypoints': }, - {'body-keypoints': , 'hand-keypoints': }, - ], # frame 0, 2 people - 4: [ - {'body-keypoints': , 'hand-keypoints': }, - {'body-keypoints': , 'hand-keypoints': }, - {'body-keypoints': , 'hand-keypoints': }, - ], # frame 4, 3 people - ... - } - Note that for the same person, their idx in the per-frame list isn't guaranteed to be consistent. - """ - if "control_input_keypoint" in data_dict: - # already processed - log.info( - f"control_input_human_kpts already processed, shape={data_dict['control_input_keypoint'].shape}, dtype={data_dict['control_input_keypoint'].dtype}, value range: {data_dict['control_input_keypoint'].min()}, {data_dict['control_input_keypoint'].max()}" - ) - return data_dict - - human_annotations = data_dict.pop("keypoint") - frames = data_dict["video"] - _, T, H, W = frames.shape - - # the frames here are a randomly sampled (e.g. 121-frame) chunk from the original video - # so we need to accordingly only use the human annotations of the sampled frames. - frame_start = data_dict["frame_start"] - frame_end = data_dict["frame_end"] - frame_indices = np.arange(frame_start, frame_end).tolist() - assert ( - len(frame_indices) == T - ), f"frame_indices length {len(frame_indices)} != T {T}, likely due to video decoder using different fps, i.e. sample with stride. Need to return frame indices from video decoder." - - try: - # same dict format as `human_annotations` but now every frame has an annotation - kpts_nparray_dict = self.get_kpts_from_annotations(human_annotations, T, frame_indices) - except ValueError as e: - log.error(f"Error in loading kpts from annotated data: {e}") - kpts_nparray_dict = {} - raise e - - try: - # Colored human keypoints plotted on black background. All persons in the same frame are plotted together. - # np.array of shape: [C, T, H, W]. - kpts_cond_video = self.plot_kpt_video( - kpts_nparray_dict, - H, - W, - kpt_thr=self.kpt_thr, - openpose_format=self.use_openpose_format, - line_width=self.line_width, - ) - except ValueError as e: - log.error(f"Error in plot_kpt_video: {e}") - kpts_cond_video = np.zeros_like(frames) - - key_out = self.output_keys[0] - - data_dict[key_out] = torch.from_numpy(kpts_cond_video) - return data_dict - - -class AddControlInputUpscale(Augmentor): - """ - Add control input to the data dictionary. 
control input are expanded to 3-channels - steps to add new items: modify this file, configs/conditioner.py, conditioner.py - """ - - def __init__( - self, - input_keys: list, - output_keys: Optional[list] = ["control_input_upscale"], - args: Optional[dict] = None, - use_random: bool = True, - preset_strength="medium", - **kwargs, - ) -> None: - super().__init__(input_keys, output_keys, args) - self.use_random = use_random - self.preset_strength = preset_strength - - def __call__(self, data_dict: dict, target_size: tuple = None) -> dict: - if "control_input_upscale" in data_dict: - # already processed - return data_dict - key_img = self.input_keys[1] - key_out = self.output_keys[0] - frames = data_dict[key_img] - frames = np.array(frames) # CTHW - is_image = len(frames.shape) < 4 - if is_image: - frames = frames.transpose((2, 0, 1))[:, None] - h, w = frames.shape[-2:] - frames = torch.from_numpy(frames.transpose(1, 0, 2, 3)) # TCHW - - if "__url__" in data_dict and "aspect_ratio" in data_dict["__url__"].meta.opts: - aspect_ratio = data_dict["__url__"].meta.opts["aspect_ratio"] - elif "aspect_ratio" in data_dict: # Non-webdataset format - aspect_ratio = data_dict["aspect_ratio"] - else: - aspect_ratio = "16,9" - - # Define the crop size - RES_SIZE_INFO = IMAGE_RES_SIZE_INFO if is_image else VIDEO_RES_SIZE_INFO - crop_width, crop_height = RES_SIZE_INFO["720"][aspect_ratio] - - if self.use_random: # always on for training, always off for inference - # During training, randomly crop a patch, then randomly downsize the video and resize it back. - # Determine a random crop location - top = torch.randint(0, max(0, h - crop_height) + 1, (1,)).item() - left = torch.randint(0, max(0, w - crop_width) + 1, (1,)).item() - cropped_frames = frames[:, :, top : top + crop_height, left : left + crop_width] - - # Randomly downsample the video - # for 360p, 720p, 1080p -> 4k - scaler = np.random.choice([1 / 6, 1 / 3, 0.5], p=[0.3, 0.5, 0.2]) - small_crop_width = int(crop_width * scaler) - small_crop_height = int(crop_height * scaler) - resized_frames = transforms_F.resize( - cropped_frames, - size=(small_crop_height, small_crop_width), - interpolation=transforms_F.InterpolationMode.BICUBIC, - antialias=True, - ) - # Upsample to target size - resized_frames = transforms_F.resize( - resized_frames, - size=(crop_height, crop_width), - interpolation=transforms_F.InterpolationMode.BILINEAR, - ) - else: - if target_size is None: # for validation - # During validation, center crop a patch, then resize to target size. - if self.preset_strength == "low": - scaler = 0.5 - elif self.preset_strength == "medium": - scaler = 1 / 3 - else: - scaler = 1 / 6 - small_crop_width = int(crop_width * scaler) - small_crop_height = int(crop_height * scaler) - - # Center crop during inference - top = (h - small_crop_height) // 2 - left = (w - small_crop_width) // 2 - - # Perform the crop - frames = frames[:, :, top : top + small_crop_height, left : left + small_crop_width] - # Upsample to target size - resized_frames = transforms_F.resize( - frames, - size=(crop_height, crop_width), - interpolation=transforms_F.InterpolationMode.BILINEAR, - ) - else: # for inference - # During inference, directly resize to target size. 
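The training branch above builds the "upscale" control by cropping a patch and round-tripping it through a lower resolution. A minimal sketch of that degradation on a synthetic clip; the 1280x704 patch size and the 1/3 scale factor are illustrative only.

import torch
import torchvision.transforms.functional as transforms_F

clip = torch.randint(0, 256, (8, 3, 720, 1280), dtype=torch.uint8)   # (T, C, H, W)
crop = clip[:, :, :704, :1280]                                        # patch at the target aspect ratio
small = transforms_F.resize(
    crop, size=(704 // 3, 1280 // 3),
    interpolation=transforms_F.InterpolationMode.BICUBIC, antialias=True,
)
degraded = transforms_F.resize(
    small, size=(704, 1280), interpolation=transforms_F.InterpolationMode.BILINEAR,
)
print(degraded.shape)   # torch.Size([8, 3, 704, 1280]); the blurry clip becomes the control input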
- new_h, new_w = target_size - resized_frames = transforms_F.resize( - frames, - size=(new_h, new_w), - interpolation=transforms_F.InterpolationMode.BILINEAR, - ) - cropped_frames = resized_frames - - resized_frames = resized_frames.permute(1, 0, 2, 3).contiguous() # CTHW - cropped_frames = cropped_frames.permute(1, 0, 2, 3).contiguous() # CTHW - - if is_image: - resized_frames = resized_frames[:, 0] - cropped_frames = cropped_frames[:, 0] - data_dict[key_out] = resized_frames - data_dict[key_img] = cropped_frames - return data_dict - - -if __name__ == "__main__": - import sys - - from cosmos_transfer1.auxiliary.guardrail.common.io_utils import save_video - from cosmos_transfer1.diffusion.config.transfer.augmentors import BilateralOnlyBlurAugmentorConfig - from cosmos_transfer1.diffusion.inference.inference_utils import read_video_or_image_into_frames_BCTHW - - path_in = sys.argv[1] - - def main(input_file_path: str) -> None: - max_length = 10 - video_input = read_video_or_image_into_frames_BCTHW(input_file_path, normalize=False)[0, :, :max_length] - C, T, H, W = video_input.shape - blur_processes = { - "bilateral": BilateralOnlyBlurAugmentorConfig, - } - for blur_name, blur_process in blur_processes.items(): - for preset_strength in ["low", "medium", "high"]: - process = get_augmentor_for_eval( - "video", - "control_input_vis", - preset_blur_strength=preset_strength, - blur_config=blur_process[preset_strength], - ) - output = process({"video": video_input}) - output = output["control_input_vis"].numpy().transpose((1, 2, 3, 0)) - - output_file_path = f"{input_file_path[:-4]}_{blur_name}_{preset_strength}.mp4" - save_video( - frames=output, - fps=5, - filepath=output_file_path, - ) - - print(f"Output video saved as {output_file_path}") - - main(path_in) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/guided_filter.py b/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/guided_filter.py deleted file mode 100644 index 248cd038b61ff82fb125ff23c2da191ce951a73e..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/guided_filter.py +++ /dev/null @@ -1,296 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# -*- coding: utf-8 -*- -# @package guided_filter.core.filters -# -# Implementation of guided filter. -# * GuidedFilter: Original guided filter. -# * FastGuidedFilter: Fast version of the guided filter. -# @author tody -# @date 2015/08/26 - -import cv2 -import numpy as np - - -# Convert image into float32 type. -def to32F(img): - if img.dtype == np.float32: - return img - return (1.0 / 255.0) * np.float32(img) - - -# Convert image into uint8 type. -def to8U(img): - if img.dtype == np.uint8: - return img - return np.clip(np.uint8(255.0 * img), 0, 255) - - -# Return if the input image is gray or not. 
-def _isGray(I): - return len(I.shape) == 2 - - -# Return down sampled image. -# @param scale (w/s, h/s) image will be created. -# @param shape I.shape[:2]=(h, w). numpy friendly size parameter. -def _downSample(I, scale=4, shape=None): - if shape is not None: - h, w = shape - return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST) - - h, w = I.shape[:2] - return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST) - - -# Return up sampled image. -# @param scale (w*s, h*s) image will be created. -# @param shape I.shape[:2]=(h, w). numpy friendly size parameter. -def _upSample(I, scale=2, shape=None): - if shape is not None: - h, w = shape - return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR) - - h, w = I.shape[:2] - return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR) - - -# Fast guide filter. -class FastGuidedFilter: - # Constructor. - # @param I Input guidance image. Color or gray. - # @param radius Radius of Guided Filter. - # @param epsilon Regularization term of Guided Filter. - # @param scale Down sampled scale. - def __init__(self, I, radius=5, epsilon=0.4, scale=4): - I_32F = to32F(I) - self._I = I_32F - h, w = I.shape[:2] - - I_sub = _downSample(I_32F, scale) - - self._I_sub = I_sub - radius = int(radius / scale) - - if _isGray(I): - self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon) - else: - self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon) - - # Apply filter for the input image. - # @param p Input image for the filtering. - def filter(self, p): - p_32F = to32F(p) - shape_original = p.shape[:2] - - p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2]) - - if _isGray(p_sub): - return self._filterGray(p_sub, shape_original) - - cs = p.shape[2] - q = np.array(p_32F) - - for ci in range(cs): - q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original) - return to8U(q) - - def _filterGray(self, p_sub, shape_original): - ab_sub = self._guided_filter._computeCoefficients(p_sub) - ab = [_upSample(abi, shape=shape_original) for abi in ab_sub] - return self._guided_filter._computeOutput(ab, self._I) - - -# Guide filter. -class GuidedFilter: - # Constructor. - # @param I Input guidance image. Color or gray. - # @param radius Radius of Guided Filter. - # @param epsilon Regularization term of Guided Filter. - def __init__(self, I, radius=5, epsilon=0.4): - I_32F = to32F(I) - - if _isGray(I): - self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon) - else: - self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon) - - # Apply filter for the input image. - # @param p Input image for the filtering. - def filter(self, p): - return to8U(self._guided_filter.filter(p)) - - -# Common parts of guided filter. -# -# This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor. -# Based on guided_filter._computeCoefficients, guided_filter._computeOutput, -# GuidedFilterCommon.filter computes filtered image for color and gray. -class GuidedFilterCommon: - def __init__(self, guided_filter): - self._guided_filter = guided_filter - - # Apply filter for the input image. - # @param p Input image for the filtering. 
- def filter(self, p): - p_32F = to32F(p) - if _isGray(p_32F): - return self._filterGray(p_32F) - - cs = p.shape[2] - q = np.array(p_32F) - - for ci in range(cs): - q[:, :, ci] = self._filterGray(p_32F[:, :, ci]) - return q - - def _filterGray(self, p): - ab = self._guided_filter._computeCoefficients(p) - return self._guided_filter._computeOutput(ab, self._guided_filter._I) - - -# Guided filter for gray guidance image. -class GuidedFilterGray: - # @param I Input gray guidance image. - # @param radius Radius of Guided Filter. - # @param epsilon Regularization term of Guided Filter. - def __init__(self, I, radius=5, epsilon=0.4): - self._radius = 2 * radius + 1 - self._epsilon = epsilon - self._I = to32F(I) - self._initFilter() - self._filter_common = GuidedFilterCommon(self) - - # Apply filter for the input image. - # @param p Input image for the filtering. - def filter(self, p): - return self._filter_common.filter(p) - - def _initFilter(self): - I = self._I - r = self._radius - self._I_mean = cv2.blur(I, (r, r)) - I_mean_sq = cv2.blur(I**2, (r, r)) - self._I_var = I_mean_sq - self._I_mean**2 - - def _computeCoefficients(self, p): - r = self._radius - p_mean = cv2.blur(p, (r, r)) - p_cov = p_mean - self._I_mean * p_mean - a = p_cov / (self._I_var + self._epsilon) - b = p_mean - a * self._I_mean - a_mean = cv2.blur(a, (r, r)) - b_mean = cv2.blur(b, (r, r)) - return a_mean, b_mean - - def _computeOutput(self, ab, I): - a_mean, b_mean = ab - return a_mean * I + b_mean - - -# Guided filter for color guidance image. -class GuidedFilterColor: - # @param I Input color guidance image. - # @param radius Radius of Guided Filter. - # @param epsilon Regularization term of Guided Filter. - def __init__(self, I, radius=5, epsilon=0.2): - self._radius = 2 * radius + 1 - self._epsilon = epsilon - self._I = to32F(I) - self._initFilter() - self._filter_common = GuidedFilterCommon(self) - - # Apply filter for the input image. - # @param p Input image for the filtering. 
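For orientation, here is a compact, self-contained restatement of the classic gray-guidance guided filter (He et al.), built from the same cv2.blur box filters these classes use. It follows the textbook formulas a = cov(I, p) / (var(I) + eps), b = mean(p) - a * mean(I), q = mean(a) * I + mean(b); it is a reference sketch, not a line-for-line mirror of GuidedFilterGray.

import cv2
import numpy as np

I = np.random.rand(120, 160).astype(np.float32)                 # guidance image in [0, 1]
p = (I + 0.05 * np.random.randn(120, 160)).astype(np.float32)   # noisy image to be filtered
r, eps = 11, 0.4                                                # box window (2*radius+1) and regularizer

I_mean = cv2.blur(I, (r, r))
p_mean = cv2.blur(p, (r, r))
I_var = cv2.blur(I * I, (r, r)) - I_mean ** 2
Ip_cov = cv2.blur(I * p, (r, r)) - I_mean * p_mean

a = Ip_cov / (I_var + eps)          # how strongly the output follows the guidance
b = p_mean - a * I_mean             # offset term
q = cv2.blur(a, (r, r)) * I + cv2.blur(b, (r, r))   # edge-preserving smoothed result
print(q.shape, q.dtype)             # (120, 160) float32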
- def filter(self, p): - return self._filter_common.filter(p) - - def _initFilter(self): - I = self._I - r = self._radius - eps = self._epsilon - - Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] - - self._Ir_mean = cv2.blur(Ir, (r, r)) - self._Ig_mean = cv2.blur(Ig, (r, r)) - self._Ib_mean = cv2.blur(Ib, (r, r)) - - Irr_var = cv2.blur(Ir**2, (r, r)) - self._Ir_mean**2 + eps - Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean - Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean - Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps - Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean - Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps - - Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var - Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var - Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var - Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var - Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var - Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var - - I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var - Irr_inv /= I_cov - Irg_inv /= I_cov - Irb_inv /= I_cov - Igg_inv /= I_cov - Igb_inv /= I_cov - Ibb_inv /= I_cov - - self._Irr_inv = Irr_inv - self._Irg_inv = Irg_inv - self._Irb_inv = Irb_inv - self._Igg_inv = Igg_inv - self._Igb_inv = Igb_inv - self._Ibb_inv = Ibb_inv - - def _computeCoefficients(self, p): - r = self._radius - I = self._I - Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] - - p_mean = cv2.blur(p, (r, r)) - - Ipr_mean = cv2.blur(Ir * p, (r, r)) - Ipg_mean = cv2.blur(Ig * p, (r, r)) - Ipb_mean = cv2.blur(Ib * p, (r, r)) - - Ipr_cov = Ipr_mean - self._Ir_mean * p_mean - Ipg_cov = Ipg_mean - self._Ig_mean * p_mean - Ipb_cov = Ipb_mean - self._Ib_mean * p_mean - - ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov - ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov - ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov - b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean - - ar_mean = cv2.blur(ar, (r, r)) - ag_mean = cv2.blur(ag, (r, r)) - ab_mean = cv2.blur(ab, (r, r)) - b_mean = cv2.blur(b, (r, r)) - - return ar_mean, ag_mean, ab_mean, b_mean - - def _computeOutput(self, ab, I): - ar_mean, ag_mean, ab_mean, b_mean = ab - - Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] - - q = ar_mean * Ir + ag_mean * Ig + ab_mean * Ib + b_mean - - return q diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/human_keypoint_utils.py b/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/human_keypoint_utils.py deleted file mode 100644 index 17f03ea4889b403a81389a127e5f78c9eaeede37..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/human_keypoint_utils.py +++ /dev/null @@ -1,451 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Adapted from rtmlib: - https://raw.githubusercontent.com/Tau-J/rtmlib/refs/heads/main/rtmlib/visualization/skeleton/coco133.py - https://github.com/Tau-J/rtmlib/blob/main/rtmlib/tools/pose_estimation/post_processings.py -""" -import numpy as np - -coco_wholebody_133_skeleton = dict( - name="coco133", - keypoint_info={ - 0: dict(name="nose", id=0, color=[51, 153, 255], swap=""), - 1: dict(name="left_eye", id=1, color=[51, 153, 255], swap="right_eye"), - 2: dict(name="right_eye", id=2, color=[51, 153, 255], swap="left_eye"), - 3: dict(name="left_ear", id=3, color=[51, 153, 255], swap="right_ear"), - 4: dict(name="right_ear", id=4, color=[51, 153, 255], swap="left_ear"), - 5: dict(name="left_shoulder", id=5, color=[0, 255, 0], swap="right_shoulder"), - 6: dict(name="right_shoulder", id=6, color=[255, 128, 0], swap="left_shoulder"), - 7: dict(name="left_elbow", id=7, color=[0, 255, 0], swap="right_elbow"), - 8: dict(name="right_elbow", id=8, color=[255, 128, 0], swap="left_elbow"), - 9: dict(name="left_wrist", id=9, color=[0, 255, 0], swap="right_wrist"), - 10: dict(name="right_wrist", id=10, color=[255, 128, 0], swap="left_wrist"), - 11: dict(name="left_hip", id=11, color=[0, 255, 0], swap="right_hip"), - 12: dict(name="right_hip", id=12, color=[255, 128, 0], swap="left_hip"), - 13: dict(name="left_knee", id=13, color=[0, 255, 0], swap="right_knee"), - 14: dict(name="right_knee", id=14, color=[255, 128, 0], swap="left_knee"), - 15: dict(name="left_ankle", id=15, color=[0, 255, 0], swap="right_ankle"), - 16: dict(name="right_ankle", id=16, color=[255, 128, 0], swap="left_ankle"), - 17: dict(name="left_big_toe", id=17, color=[255, 128, 0], swap="right_big_toe"), - 18: dict(name="left_small_toe", id=18, color=[255, 128, 0], swap="right_small_toe"), - 19: dict(name="left_heel", id=19, color=[255, 128, 0], swap="right_heel"), - 20: dict(name="right_big_toe", id=20, color=[255, 128, 0], swap="left_big_toe"), - 21: dict(name="right_small_toe", id=21, color=[255, 128, 0], swap="left_small_toe"), - 22: dict(name="right_heel", id=22, color=[255, 128, 0], swap="left_heel"), - 23: dict(name="face-0", id=23, color=[255, 255, 255], swap="face-16"), - 24: dict(name="face-1", id=24, color=[255, 255, 255], swap="face-15"), - 25: dict(name="face-2", id=25, color=[255, 255, 255], swap="face-14"), - 26: dict(name="face-3", id=26, color=[255, 255, 255], swap="face-13"), - 27: dict(name="face-4", id=27, color=[255, 255, 255], swap="face-12"), - 28: dict(name="face-5", id=28, color=[255, 255, 255], swap="face-11"), - 29: dict(name="face-6", id=29, color=[255, 255, 255], swap="face-10"), - 30: dict(name="face-7", id=30, color=[255, 255, 255], swap="face-9"), - 31: dict(name="face-8", id=31, color=[255, 255, 255], swap=""), - 32: dict(name="face-9", id=32, color=[255, 255, 255], swap="face-7"), - 33: dict(name="face-10", id=33, color=[255, 255, 255], swap="face-6"), - 34: dict(name="face-11", id=34, color=[255, 255, 255], swap="face-5"), - 35: dict(name="face-12", id=35, color=[255, 255, 255], swap="face-4"), - 36: dict(name="face-13", id=36, color=[255, 255, 255], swap="face-3"), - 37: dict(name="face-14", id=37, color=[255, 255, 255], swap="face-2"), - 38: dict(name="face-15", id=38, color=[255, 255, 255], swap="face-1"), - 39: dict(name="face-16", id=39, color=[255, 255, 255], swap="face-0"), - 40: dict(name="face-17", id=40, color=[255, 255, 255], swap="face-26"), - 41: dict(name="face-18", id=41, 
color=[255, 255, 255], swap="face-25"), - 42: dict(name="face-19", id=42, color=[255, 255, 255], swap="face-24"), - 43: dict(name="face-20", id=43, color=[255, 255, 255], swap="face-23"), - 44: dict(name="face-21", id=44, color=[255, 255, 255], swap="face-22"), - 45: dict(name="face-22", id=45, color=[255, 255, 255], swap="face-21"), - 46: dict(name="face-23", id=46, color=[255, 255, 255], swap="face-20"), - 47: dict(name="face-24", id=47, color=[255, 255, 255], swap="face-19"), - 48: dict(name="face-25", id=48, color=[255, 255, 255], swap="face-18"), - 49: dict(name="face-26", id=49, color=[255, 255, 255], swap="face-17"), - 50: dict(name="face-27", id=50, color=[255, 255, 255], swap=""), - 51: dict(name="face-28", id=51, color=[255, 255, 255], swap=""), - 52: dict(name="face-29", id=52, color=[255, 255, 255], swap=""), - 53: dict(name="face-30", id=53, color=[255, 255, 255], swap=""), - 54: dict(name="face-31", id=54, color=[255, 255, 255], swap="face-35"), - 55: dict(name="face-32", id=55, color=[255, 255, 255], swap="face-34"), - 56: dict(name="face-33", id=56, color=[255, 255, 255], swap=""), - 57: dict(name="face-34", id=57, color=[255, 255, 255], swap="face-32"), - 58: dict(name="face-35", id=58, color=[255, 255, 255], swap="face-31"), - 59: dict(name="face-36", id=59, color=[255, 255, 255], swap="face-45"), - 60: dict(name="face-37", id=60, color=[255, 255, 255], swap="face-44"), - 61: dict(name="face-38", id=61, color=[255, 255, 255], swap="face-43"), - 62: dict(name="face-39", id=62, color=[255, 255, 255], swap="face-42"), - 63: dict(name="face-40", id=63, color=[255, 255, 255], swap="face-47"), - 64: dict(name="face-41", id=64, color=[255, 255, 255], swap="face-46"), - 65: dict(name="face-42", id=65, color=[255, 255, 255], swap="face-39"), - 66: dict(name="face-43", id=66, color=[255, 255, 255], swap="face-38"), - 67: dict(name="face-44", id=67, color=[255, 255, 255], swap="face-37"), - 68: dict(name="face-45", id=68, color=[255, 255, 255], swap="face-36"), - 69: dict(name="face-46", id=69, color=[255, 255, 255], swap="face-41"), - 70: dict(name="face-47", id=70, color=[255, 255, 255], swap="face-40"), - 71: dict(name="face-48", id=71, color=[255, 255, 255], swap="face-54"), - 72: dict(name="face-49", id=72, color=[255, 255, 255], swap="face-53"), - 73: dict(name="face-50", id=73, color=[255, 255, 255], swap="face-52"), - 74: dict(name="face-51", id=74, color=[255, 255, 255], swap=""), - 75: dict(name="face-52", id=75, color=[255, 255, 255], swap="face-50"), - 76: dict(name="face-53", id=76, color=[255, 255, 255], swap="face-49"), - 77: dict(name="face-54", id=77, color=[255, 255, 255], swap="face-48"), - 78: dict(name="face-55", id=78, color=[255, 255, 255], swap="face-59"), - 79: dict(name="face-56", id=79, color=[255, 255, 255], swap="face-58"), - 80: dict(name="face-57", id=80, color=[255, 255, 255], swap=""), - 81: dict(name="face-58", id=81, color=[255, 255, 255], swap="face-56"), - 82: dict(name="face-59", id=82, color=[255, 255, 255], swap="face-55"), - 83: dict(name="face-60", id=83, color=[255, 255, 255], swap="face-64"), - 84: dict(name="face-61", id=84, color=[255, 255, 255], swap="face-63"), - 85: dict(name="face-62", id=85, color=[255, 255, 255], swap=""), - 86: dict(name="face-63", id=86, color=[255, 255, 255], swap="face-61"), - 87: dict(name="face-64", id=87, color=[255, 255, 255], swap="face-60"), - 88: dict(name="face-65", id=88, color=[255, 255, 255], swap="face-67"), - 89: dict(name="face-66", id=89, color=[255, 255, 255], swap=""), - 90: 
dict(name="face-67", id=90, color=[255, 255, 255], swap="face-65"), - 91: dict(name="left_hand_root", id=91, color=[255, 255, 255], swap="right_hand_root"), - 92: dict(name="left_thumb1", id=92, color=[255, 128, 0], swap="right_thumb1"), - 93: dict(name="left_thumb2", id=93, color=[255, 128, 0], swap="right_thumb2"), - 94: dict(name="left_thumb3", id=94, color=[255, 128, 0], swap="right_thumb3"), - 95: dict(name="left_thumb4", id=95, color=[255, 128, 0], swap="right_thumb4"), - 96: dict(name="left_forefinger1", id=96, color=[255, 153, 255], swap="right_forefinger1"), - 97: dict(name="left_forefinger2", id=97, color=[255, 153, 255], swap="right_forefinger2"), - 98: dict(name="left_forefinger3", id=98, color=[255, 153, 255], swap="right_forefinger3"), - 99: dict(name="left_forefinger4", id=99, color=[255, 153, 255], swap="right_forefinger4"), - 100: dict(name="left_middle_finger1", id=100, color=[102, 178, 255], swap="right_middle_finger1"), - 101: dict(name="left_middle_finger2", id=101, color=[102, 178, 255], swap="right_middle_finger2"), - 102: dict(name="left_middle_finger3", id=102, color=[102, 178, 255], swap="right_middle_finger3"), - 103: dict(name="left_middle_finger4", id=103, color=[102, 178, 255], swap="right_middle_finger4"), - 104: dict(name="left_ring_finger1", id=104, color=[255, 51, 51], swap="right_ring_finger1"), - 105: dict(name="left_ring_finger2", id=105, color=[255, 51, 51], swap="right_ring_finger2"), - 106: dict(name="left_ring_finger3", id=106, color=[255, 51, 51], swap="right_ring_finger3"), - 107: dict(name="left_ring_finger4", id=107, color=[255, 51, 51], swap="right_ring_finger4"), - 108: dict(name="left_pinky_finger1", id=108, color=[0, 255, 0], swap="right_pinky_finger1"), - 109: dict(name="left_pinky_finger2", id=109, color=[0, 255, 0], swap="right_pinky_finger2"), - 110: dict(name="left_pinky_finger3", id=110, color=[0, 255, 0], swap="right_pinky_finger3"), - 111: dict(name="left_pinky_finger4", id=111, color=[0, 255, 0], swap="right_pinky_finger4"), - 112: dict(name="right_hand_root", id=112, color=[255, 255, 255], swap="left_hand_root"), - 113: dict(name="right_thumb1", id=113, color=[255, 128, 0], swap="left_thumb1"), - 114: dict(name="right_thumb2", id=114, color=[255, 128, 0], swap="left_thumb2"), - 115: dict(name="right_thumb3", id=115, color=[255, 128, 0], swap="left_thumb3"), - 116: dict(name="right_thumb4", id=116, color=[255, 128, 0], swap="left_thumb4"), - 117: dict(name="right_forefinger1", id=117, color=[255, 153, 255], swap="left_forefinger1"), - 118: dict(name="right_forefinger2", id=118, color=[255, 153, 255], swap="left_forefinger2"), - 119: dict(name="right_forefinger3", id=119, color=[255, 153, 255], swap="left_forefinger3"), - 120: dict(name="right_forefinger4", id=120, color=[255, 153, 255], swap="left_forefinger4"), - 121: dict(name="right_middle_finger1", id=121, color=[102, 178, 255], swap="left_middle_finger1"), - 122: dict(name="right_middle_finger2", id=122, color=[102, 178, 255], swap="left_middle_finger2"), - 123: dict(name="right_middle_finger3", id=123, color=[102, 178, 255], swap="left_middle_finger3"), - 124: dict(name="right_middle_finger4", id=124, color=[102, 178, 255], swap="left_middle_finger4"), - 125: dict(name="right_ring_finger1", id=125, color=[255, 51, 51], swap="left_ring_finger1"), - 126: dict(name="right_ring_finger2", id=126, color=[255, 51, 51], swap="left_ring_finger2"), - 127: dict(name="right_ring_finger3", id=127, color=[255, 51, 51], swap="left_ring_finger3"), - 128: dict(name="right_ring_finger4", 
id=128, color=[255, 51, 51], swap="left_ring_finger4"), - 129: dict(name="right_pinky_finger1", id=129, color=[0, 255, 0], swap="left_pinky_finger1"), - 130: dict(name="right_pinky_finger2", id=130, color=[0, 255, 0], swap="left_pinky_finger2"), - 131: dict(name="right_pinky_finger3", id=131, color=[0, 255, 0], swap="left_pinky_finger3"), - 132: dict(name="right_pinky_finger4", id=132, color=[0, 255, 0], swap="left_pinky_finger4"), - }, - skeleton_info={ - 0: dict(link=("left_ankle", "left_knee"), id=0, color=[0, 255, 0]), - 1: dict(link=("left_knee", "left_hip"), id=1, color=[0, 255, 0]), - 2: dict(link=("right_ankle", "right_knee"), id=2, color=[255, 128, 0]), - 3: dict(link=("right_knee", "right_hip"), id=3, color=[255, 128, 0]), - 4: dict(link=("left_hip", "right_hip"), id=4, color=[51, 153, 255]), - 5: dict(link=("left_shoulder", "left_hip"), id=5, color=[51, 153, 255]), - 6: dict(link=("right_shoulder", "right_hip"), id=6, color=[51, 153, 255]), - 7: dict(link=("left_shoulder", "right_shoulder"), id=7, color=[51, 153, 255]), - 8: dict(link=("left_shoulder", "left_elbow"), id=8, color=[0, 255, 0]), - 9: dict(link=("right_shoulder", "right_elbow"), id=9, color=[255, 128, 0]), - 10: dict(link=("left_elbow", "left_wrist"), id=10, color=[0, 255, 0]), - 11: dict(link=("right_elbow", "right_wrist"), id=11, color=[255, 128, 0]), - 12: dict(link=("left_eye", "right_eye"), id=12, color=[51, 153, 255]), - 13: dict(link=("nose", "left_eye"), id=13, color=[51, 153, 255]), - 14: dict(link=("nose", "right_eye"), id=14, color=[51, 153, 255]), - 15: dict(link=("left_eye", "left_ear"), id=15, color=[51, 153, 255]), - 16: dict(link=("right_eye", "right_ear"), id=16, color=[51, 153, 255]), - 17: dict(link=("left_ear", "left_shoulder"), id=17, color=[51, 153, 255]), - 18: dict(link=("right_ear", "right_shoulder"), id=18, color=[51, 153, 255]), - 19: dict(link=("left_ankle", "left_big_toe"), id=19, color=[0, 255, 0]), - 20: dict(link=("left_ankle", "left_small_toe"), id=20, color=[0, 255, 0]), - 21: dict(link=("left_ankle", "left_heel"), id=21, color=[0, 255, 0]), - 22: dict(link=("right_ankle", "right_big_toe"), id=22, color=[255, 128, 0]), - 23: dict(link=("right_ankle", "right_small_toe"), id=23, color=[255, 128, 0]), - 24: dict(link=("right_ankle", "right_heel"), id=24, color=[255, 128, 0]), - 25: dict(link=("left_hand_root", "left_thumb1"), id=25, color=[255, 128, 0]), - 26: dict(link=("left_thumb1", "left_thumb2"), id=26, color=[255, 128, 0]), - 27: dict(link=("left_thumb2", "left_thumb3"), id=27, color=[255, 128, 0]), - 28: dict(link=("left_thumb3", "left_thumb4"), id=28, color=[255, 128, 0]), - 29: dict(link=("left_hand_root", "left_forefinger1"), id=29, color=[255, 153, 255]), - 30: dict(link=("left_forefinger1", "left_forefinger2"), id=30, color=[255, 153, 255]), - 31: dict(link=("left_forefinger2", "left_forefinger3"), id=31, color=[255, 153, 255]), - 32: dict(link=("left_forefinger3", "left_forefinger4"), id=32, color=[255, 153, 255]), - 33: dict(link=("left_hand_root", "left_middle_finger1"), id=33, color=[102, 178, 255]), - 34: dict(link=("left_middle_finger1", "left_middle_finger2"), id=34, color=[102, 178, 255]), - 35: dict(link=("left_middle_finger2", "left_middle_finger3"), id=35, color=[102, 178, 255]), - 36: dict(link=("left_middle_finger3", "left_middle_finger4"), id=36, color=[102, 178, 255]), - 37: dict(link=("left_hand_root", "left_ring_finger1"), id=37, color=[255, 51, 51]), - 38: dict(link=("left_ring_finger1", "left_ring_finger2"), id=38, color=[255, 51, 51]), - 39: 
dict(link=("left_ring_finger2", "left_ring_finger3"), id=39, color=[255, 51, 51]), - 40: dict(link=("left_ring_finger3", "left_ring_finger4"), id=40, color=[255, 51, 51]), - 41: dict(link=("left_hand_root", "left_pinky_finger1"), id=41, color=[0, 255, 0]), - 42: dict(link=("left_pinky_finger1", "left_pinky_finger2"), id=42, color=[0, 255, 0]), - 43: dict(link=("left_pinky_finger2", "left_pinky_finger3"), id=43, color=[0, 255, 0]), - 44: dict(link=("left_pinky_finger3", "left_pinky_finger4"), id=44, color=[0, 255, 0]), - 45: dict(link=("right_hand_root", "right_thumb1"), id=45, color=[255, 128, 0]), - 46: dict(link=("right_thumb1", "right_thumb2"), id=46, color=[255, 128, 0]), - 47: dict(link=("right_thumb2", "right_thumb3"), id=47, color=[255, 128, 0]), - 48: dict(link=("right_thumb3", "right_thumb4"), id=48, color=[255, 128, 0]), - 49: dict(link=("right_hand_root", "right_forefinger1"), id=49, color=[255, 153, 255]), - 50: dict(link=("right_forefinger1", "right_forefinger2"), id=50, color=[255, 153, 255]), - 51: dict(link=("right_forefinger2", "right_forefinger3"), id=51, color=[255, 153, 255]), - 52: dict(link=("right_forefinger3", "right_forefinger4"), id=52, color=[255, 153, 255]), - 53: dict(link=("right_hand_root", "right_middle_finger1"), id=53, color=[102, 178, 255]), - 54: dict(link=("right_middle_finger1", "right_middle_finger2"), id=54, color=[102, 178, 255]), - 55: dict(link=("right_middle_finger2", "right_middle_finger3"), id=55, color=[102, 178, 255]), - 56: dict(link=("right_middle_finger3", "right_middle_finger4"), id=56, color=[102, 178, 255]), - 57: dict(link=("right_hand_root", "right_ring_finger1"), id=57, color=[255, 51, 51]), - 58: dict(link=("right_ring_finger1", "right_ring_finger2"), id=58, color=[255, 51, 51]), - 59: dict(link=("right_ring_finger2", "right_ring_finger3"), id=59, color=[255, 51, 51]), - 60: dict(link=("right_ring_finger3", "right_ring_finger4"), id=60, color=[255, 51, 51]), - 61: dict(link=("right_hand_root", "right_pinky_finger1"), id=61, color=[0, 255, 0]), - 62: dict(link=("right_pinky_finger1", "right_pinky_finger2"), id=62, color=[0, 255, 0]), - 63: dict(link=("right_pinky_finger2", "right_pinky_finger3"), id=63, color=[0, 255, 0]), - 64: dict(link=("right_pinky_finger3", "right_pinky_finger4"), id=64, color=[0, 255, 0]), - }, -) - - -# openpose 134 keypoints, can be converted from the coco-133 keypoints. 
-openpose134_skeleton = dict( - name="openpose134", - keypoint_info={ - 0: dict(name="nose", id=0, color=[255, 0, 0], swap=""), - 1: dict(name="neck", id=1, color=[255, 85, 0], swap=""), - 2: dict(name="right_shoulder", id=2, color=[255, 170, 0], swap="left_shoulder"), - 3: dict(name="right_elbow", id=3, color=[255, 255, 0], swap="left_elbow"), - 4: dict(name="right_wrist", id=4, color=[170, 255, 0], swap="left_wrist"), - 5: dict(name="left_shoulder", id=5, color=[85, 255, 0], swap="right_shoulder"), - 6: dict(name="left_elbow", id=6, color=[0, 255, 0], swap="right_elbow"), - 7: dict(name="left_wrist", id=7, color=[0, 255, 85], swap="right_wrist"), - 8: dict(name="right_hip", id=8, color=[0, 255, 170], swap="left_hip"), - 9: dict(name="right_knee", id=9, color=[0, 255, 255], swap="left_knee"), - 10: dict(name="right_ankle", id=10, color=[0, 170, 255], swap="left_ankle"), - 11: dict(name="left_hip", id=11, color=[0, 85, 255], swap="right_hip"), - 12: dict(name="left_knee", id=12, color=[0, 0, 255], swap="right_knee"), - 13: dict(name="left_ankle", id=13, color=[85, 0, 255], swap="right_ankle"), - 14: dict(name="right_eye", id=14, color=[170, 0, 255], swap="left_eye"), - 15: dict(name="left_eye", id=15, color=[255, 0, 255], swap="right_eye"), - 16: dict(name="right_ear", id=16, color=[255, 0, 170], swap="left_ear"), - 17: dict(name="left_ear", id=17, color=[255, 0, 85], swap="right_ear"), - 18: dict(name="left_big_toe", id=17, color=[0, 0, 0], swap="right_big_toe"), - 19: dict(name="left_small_toe", id=18, color=[0, 0, 0], swap="right_small_toe"), - 20: dict(name="left_heel", id=19, color=[0, 0, 0], swap="right_heel"), - 21: dict(name="right_big_toe", id=20, color=[0, 0, 0], swap="left_big_toe"), - 22: dict(name="right_small_toe", id=21, color=[0, 0, 0], swap="left_small_toe"), - 23: dict(name="right_heel", id=22, color=[0, 0, 0], swap="left_heel"), - 24: dict(name="face-0", id=23, color=[255, 255, 255], swap="face-16"), - 25: dict(name="face-1", id=24, color=[255, 255, 255], swap="face-15"), - 26: dict(name="face-2", id=25, color=[255, 255, 255], swap="face-14"), - 27: dict(name="face-3", id=26, color=[255, 255, 255], swap="face-13"), - 28: dict(name="face-4", id=27, color=[255, 255, 255], swap="face-12"), - 29: dict(name="face-5", id=28, color=[255, 255, 255], swap="face-11"), - 30: dict(name="face-6", id=29, color=[255, 255, 255], swap="face-10"), - 31: dict(name="face-7", id=30, color=[255, 255, 255], swap="face-9"), - 32: dict(name="face-8", id=31, color=[255, 255, 255], swap=""), - 33: dict(name="face-9", id=32, color=[255, 255, 255], swap="face-7"), - 34: dict(name="face-10", id=33, color=[255, 255, 255], swap="face-6"), - 35: dict(name="face-11", id=34, color=[255, 255, 255], swap="face-5"), - 36: dict(name="face-12", id=35, color=[255, 255, 255], swap="face-4"), - 37: dict(name="face-13", id=36, color=[255, 255, 255], swap="face-3"), - 38: dict(name="face-14", id=37, color=[255, 255, 255], swap="face-2"), - 39: dict(name="face-15", id=38, color=[255, 255, 255], swap="face-1"), - 40: dict(name="face-16", id=39, color=[255, 255, 255], swap="face-0"), - 41: dict(name="face-17", id=40, color=[255, 255, 255], swap="face-26"), - 42: dict(name="face-18", id=41, color=[255, 255, 255], swap="face-25"), - 43: dict(name="face-19", id=42, color=[255, 255, 255], swap="face-24"), - 44: dict(name="face-20", id=43, color=[255, 255, 255], swap="face-23"), - 45: dict(name="face-21", id=44, color=[255, 255, 255], swap="face-22"), - 46: dict(name="face-22", id=45, color=[255, 255, 255], 
swap="face-21"), - 47: dict(name="face-23", id=46, color=[255, 255, 255], swap="face-20"), - 48: dict(name="face-24", id=47, color=[255, 255, 255], swap="face-19"), - 49: dict(name="face-25", id=48, color=[255, 255, 255], swap="face-18"), - 50: dict(name="face-26", id=49, color=[255, 255, 255], swap="face-17"), - 51: dict(name="face-27", id=50, color=[255, 255, 255], swap=""), - 52: dict(name="face-28", id=51, color=[255, 255, 255], swap=""), - 53: dict(name="face-29", id=52, color=[255, 255, 255], swap=""), - 54: dict(name="face-30", id=53, color=[255, 255, 255], swap=""), - 55: dict(name="face-31", id=54, color=[255, 255, 255], swap="face-35"), - 56: dict(name="face-32", id=55, color=[255, 255, 255], swap="face-34"), - 57: dict(name="face-33", id=56, color=[255, 255, 255], swap=""), - 58: dict(name="face-34", id=57, color=[255, 255, 255], swap="face-32"), - 59: dict(name="face-35", id=58, color=[255, 255, 255], swap="face-31"), - 60: dict(name="face-36", id=59, color=[255, 255, 255], swap="face-45"), - 61: dict(name="face-37", id=60, color=[255, 255, 255], swap="face-44"), - 62: dict(name="face-38", id=61, color=[255, 255, 255], swap="face-43"), - 63: dict(name="face-39", id=62, color=[255, 255, 255], swap="face-42"), - 64: dict(name="face-40", id=63, color=[255, 255, 255], swap="face-47"), - 65: dict(name="face-41", id=64, color=[255, 255, 255], swap="face-46"), - 66: dict(name="face-42", id=65, color=[255, 255, 255], swap="face-39"), - 67: dict(name="face-43", id=66, color=[255, 255, 255], swap="face-38"), - 68: dict(name="face-44", id=67, color=[255, 255, 255], swap="face-37"), - 69: dict(name="face-45", id=68, color=[255, 255, 255], swap="face-36"), - 70: dict(name="face-46", id=69, color=[255, 255, 255], swap="face-41"), - 71: dict(name="face-47", id=70, color=[255, 255, 255], swap="face-40"), - 72: dict(name="face-48", id=71, color=[255, 255, 255], swap="face-54"), - 73: dict(name="face-49", id=72, color=[255, 255, 255], swap="face-53"), - 74: dict(name="face-50", id=73, color=[255, 255, 255], swap="face-52"), - 75: dict(name="face-51", id=74, color=[255, 255, 255], swap=""), - 76: dict(name="face-52", id=75, color=[255, 255, 255], swap="face-50"), - 77: dict(name="face-53", id=76, color=[255, 255, 255], swap="face-49"), - 78: dict(name="face-54", id=77, color=[255, 255, 255], swap="face-48"), - 79: dict(name="face-55", id=78, color=[255, 255, 255], swap="face-59"), - 80: dict(name="face-56", id=79, color=[255, 255, 255], swap="face-58"), - 81: dict(name="face-57", id=80, color=[255, 255, 255], swap=""), - 82: dict(name="face-58", id=81, color=[255, 255, 255], swap="face-56"), - 83: dict(name="face-59", id=82, color=[255, 255, 255], swap="face-55"), - 84: dict(name="face-60", id=83, color=[255, 255, 255], swap="face-64"), - 85: dict(name="face-61", id=84, color=[255, 255, 255], swap="face-63"), - 86: dict(name="face-62", id=85, color=[255, 255, 255], swap=""), - 87: dict(name="face-63", id=86, color=[255, 255, 255], swap="face-61"), - 88: dict(name="face-64", id=87, color=[255, 255, 255], swap="face-60"), - 89: dict(name="face-65", id=88, color=[255, 255, 255], swap="face-67"), - 90: dict(name="face-66", id=89, color=[255, 255, 255], swap=""), - 91: dict(name="face-67", id=90, color=[255, 255, 255], swap="face-65"), - 92: dict(name="left_hand_root", id=92, color=[0, 0, 255], swap="right_hand_root"), - 93: dict(name="left_thumb1", id=93, color=[0, 0, 255], swap="right_thumb1"), - 94: dict(name="left_thumb2", id=94, color=[0, 0, 255], swap="right_thumb2"), - 95: 
dict(name="left_thumb3", id=95, color=[0, 0, 255], swap="right_thumb3"), - 96: dict(name="left_thumb4", id=96, color=[0, 0, 255], swap="right_thumb4"), - 97: dict(name="left_forefinger1", id=97, color=[0, 0, 255], swap="right_forefinger1"), - 98: dict(name="left_forefinger2", id=98, color=[0, 0, 255], swap="right_forefinger2"), - 99: dict(name="left_forefinger3", id=99, color=[0, 0, 255], swap="right_forefinger3"), - 100: dict(name="left_forefinger4", id=100, color=[0, 0, 255], swap="right_forefinger4"), - 101: dict(name="left_middle_finger1", id=101, color=[0, 0, 255], swap="right_middle_finger1"), - 102: dict(name="left_middle_finger2", id=102, color=[0, 0, 255], swap="right_middle_finger2"), - 103: dict(name="left_middle_finger3", id=103, color=[0, 0, 255], swap="right_middle_finger3"), - 104: dict(name="left_middle_finger4", id=104, color=[0, 0, 255], swap="right_middle_finger4"), - 105: dict(name="left_ring_finger1", id=105, color=[0, 0, 255], swap="right_ring_finger1"), - 106: dict(name="left_ring_finger2", id=106, color=[0, 0, 255], swap="right_ring_finger2"), - 107: dict(name="left_ring_finger3", id=107, color=[0, 0, 255], swap="right_ring_finger3"), - 108: dict(name="left_ring_finger4", id=108, color=[0, 0, 255], swap="right_ring_finger4"), - 109: dict(name="left_pinky_finger1", id=109, color=[0, 0, 255], swap="right_pinky_finger1"), - 110: dict(name="left_pinky_finger2", id=110, color=[0, 0, 255], swap="right_pinky_finger2"), - 111: dict(name="left_pinky_finger3", id=111, color=[0, 0, 255], swap="right_pinky_finger3"), - 112: dict(name="left_pinky_finger4", id=112, color=[0, 0, 255], swap="right_pinky_finger4"), - 113: dict(name="right_hand_root", id=113, color=[0, 0, 255], swap="left_hand_root"), - 114: dict(name="right_thumb1", id=114, color=[0, 0, 255], swap="left_thumb1"), - 115: dict(name="right_thumb2", id=115, color=[0, 0, 255], swap="left_thumb2"), - 116: dict(name="right_thumb3", id=116, color=[0, 0, 255], swap="left_thumb3"), - 117: dict(name="right_thumb4", id=117, color=[0, 0, 255], swap="left_thumb4"), - 118: dict(name="right_forefinger1", id=118, color=[0, 0, 255], swap="left_forefinger1"), - 119: dict(name="right_forefinger2", id=119, color=[0, 0, 255], swap="left_forefinger2"), - 120: dict(name="right_forefinger3", id=120, color=[0, 0, 255], swap="left_forefinger3"), - 121: dict(name="right_forefinger4", id=121, color=[0, 0, 255], swap="left_forefinger4"), - 122: dict(name="right_middle_finger1", id=122, color=[0, 0, 255], swap="left_middle_finger1"), - 123: dict(name="right_middle_finger2", id=123, color=[0, 0, 255], swap="left_middle_finger2"), - 124: dict(name="right_middle_finger3", id=124, color=[0, 0, 255], swap="left_middle_finger3"), - 125: dict(name="right_middle_finger4", id=125, color=[0, 0, 255], swap="left_middle_finger4"), - 126: dict(name="right_ring_finger1", id=126, color=[0, 0, 255], swap="left_ring_finger1"), - 127: dict(name="right_ring_finger2", id=127, color=[0, 0, 255], swap="left_ring_finger2"), - 128: dict(name="right_ring_finger3", id=128, color=[0, 0, 255], swap="left_ring_finger3"), - 129: dict(name="right_ring_finger4", id=129, color=[0, 0, 255], swap="left_ring_finger4"), - 130: dict(name="right_pinky_finger1", id=130, color=[0, 0, 255], swap="left_pinky_finger1"), - 131: dict(name="right_pinky_finger2", id=131, color=[0, 0, 255], swap="left_pinky_finger2"), - 132: dict(name="right_pinky_finger3", id=132, color=[0, 0, 255], swap="left_pinky_finger3"), - 133: dict(name="right_pinky_finger4", id=133, color=[0, 0, 255], 
swap="left_pinky_finger4"), - }, - skeleton_info={ - 0: dict(link=("neck", "right_shoulder"), id=0, color=[255, 0, 0]), - 1: dict(link=("neck", "left_shoulder"), id=1, color=[255, 85, 0]), - 2: dict(link=("right_shoulder", "right_elbow"), id=2, color=[255, 170, 0]), - 3: dict(link=("right_elbow", "right_wrist"), id=3, color=[255, 255, 0]), - 4: dict(link=("left_shoulder", "left_elbow"), id=4, color=[170, 255, 0]), - 5: dict(link=("left_elbow", "left_wrist"), id=5, color=[85, 255, 0]), - 6: dict(link=("neck", "right_hip"), id=6, color=[0, 255, 0]), - 7: dict(link=("right_hip", "right_knee"), id=7, color=[0, 255, 85]), - 8: dict(link=("right_knee", "right_ankle"), id=8, color=[0, 255, 170]), - 9: dict(link=("neck", "left_hip"), id=9, color=[0, 255, 225]), - 10: dict(link=("left_hip", "left_knee"), id=10, color=[0, 170, 255]), - 11: dict(link=("left_knee", "left_ankle"), id=11, color=[0, 85, 255]), - 12: dict(link=("neck", "nose"), id=12, color=[0, 0, 255]), - 13: dict(link=("nose", "right_eye"), id=13, color=[255, 0, 170]), - 14: dict(link=("right_eye", "right_ear"), id=14, color=[170, 0, 255]), - 15: dict(link=("nose", "left_eye"), id=15, color=[255, 0, 255]), - 16: dict(link=("left_eye", "left_ear"), id=16, color=[255, 0, 170]), - 17: dict(link=("left_hand_root", "left_thumb1"), id=17, color=[255, 0, 0]), - 18: dict(link=("left_thumb1", "left_thumb2"), id=18, color=[255, 76, 0]), - 19: dict(link=("left_thumb2", "left_thumb3"), id=19, color=[255, 153, 0]), - 20: dict(link=("left_thumb3", "left_thumb4"), id=20, color=[255, 230, 0]), - 21: dict(link=("left_hand_root", "left_forefinger1"), id=21, color=[204, 255, 0]), - 22: dict(link=("left_forefinger1", "left_forefinger2"), id=22, color=[128, 255, 0]), - 23: dict(link=("left_forefinger2", "left_forefinger3"), id=23, color=[51, 255, 0]), - 24: dict(link=("left_forefinger3", "left_forefinger4"), id=24, color=[0, 255, 26]), - 25: dict(link=("left_hand_root", "left_middle_finger1"), id=25, color=[0, 255, 102]), - 26: dict(link=("left_middle_finger1", "left_middle_finger2"), id=26, color=[0, 255, 178]), - 27: dict(link=("left_middle_finger2", "left_middle_finger3"), id=27, color=[0, 255, 255]), - 28: dict(link=("left_middle_finger3", "left_middle_finger4"), id=28, color=[0, 178, 255]), - 29: dict(link=("left_hand_root", "left_ring_finger1"), id=29, color=[0, 102, 255]), - 30: dict(link=("left_ring_finger1", "left_ring_finger2"), id=30, color=[0, 26, 255]), - 31: dict(link=("left_ring_finger2", "left_ring_finger3"), id=31, color=[51, 0, 255]), - 32: dict(link=("left_ring_finger3", "left_ring_finger4"), id=32, color=[128, 0, 255]), - 33: dict(link=("left_hand_root", "left_pinky_finger1"), id=33, color=[204, 0, 255]), - 34: dict(link=("left_pinky_finger1", "left_pinky_finger2"), id=34, color=[255, 0, 230]), - 35: dict(link=("left_pinky_finger2", "left_pinky_finger3"), id=35, color=[255, 0, 153]), - 36: dict(link=("left_pinky_finger3", "left_pinky_finger4"), id=36, color=[255, 0, 76]), - 37: dict(link=("right_hand_root", "right_thumb1"), id=37, color=[255, 0, 0]), - 38: dict(link=("right_thumb1", "right_thumb2"), id=38, color=[255, 76, 0]), - 39: dict(link=("right_thumb2", "right_thumb3"), id=39, color=[255, 153, 0]), - 40: dict(link=("right_thumb3", "right_thumb4"), id=40, color=[255, 230, 0]), - 41: dict(link=("right_hand_root", "right_forefinger1"), id=41, color=[204, 255, 0]), - 42: dict(link=("right_forefinger1", "right_forefinger2"), id=42, color=[128, 255, 0]), - 43: dict(link=("right_forefinger2", "right_forefinger3"), id=43, color=[51, 255, 
0]), - 44: dict(link=("right_forefinger3", "right_forefinger4"), id=44, color=[0, 255, 26]), - 45: dict(link=("right_hand_root", "right_middle_finger1"), id=45, color=[0, 255, 102]), - 46: dict(link=("right_middle_finger1", "right_middle_finger2"), id=46, color=[0, 255, 178]), - 47: dict(link=("right_middle_finger2", "right_middle_finger3"), id=47, color=[255, 255, 255]), - 48: dict(link=("right_middle_finger3", "right_middle_finger4"), id=48, color=[0, 178, 255]), - 49: dict(link=("right_hand_root", "right_ring_finger1"), id=49, color=[0, 102, 255]), - 50: dict(link=("right_ring_finger1", "right_ring_finger2"), id=50, color=[0, 26, 255]), - 51: dict(link=("right_ring_finger2", "right_ring_finger3"), id=51, color=[51, 0, 255]), - 52: dict(link=("right_ring_finger3", "right_ring_finger4"), id=52, color=[128, 0, 255]), - 53: dict(link=("right_hand_root", "right_pinky_finger1"), id=53, color=[204, 0, 255]), - 54: dict(link=("right_pinky_finger1", "right_pinky_finger2"), id=54, color=[255, 0, 230]), - 55: dict(link=("right_pinky_finger2", "right_pinky_finger3"), id=55, color=[255, 0, 153]), - 56: dict(link=("right_pinky_finger3", "right_pinky_finger4"), id=56, color=[255, 0, 76]), - }, -) - - -def convert_coco_to_openpose(keypoints, scores): - assert len(keypoints.shape) == 2 - keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1) - keypoints_info = keypoints_info[None, ...] - - # compute neck joint - neck = np.mean(keypoints_info[:, [5, 6]], axis=1) - - # neck score when visualizing pred - neck[:, 2:3] = np.where( - keypoints_info[:, 5, 2:3] > keypoints_info[:, 6, 2:3], keypoints_info[:, 6, 2:3], keypoints_info[:, 5, 2:3] - ) - new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1) - - mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3] - openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17] - new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx] - keypoints_info = new_keypoints_info - - keypoints, scores = keypoints_info[..., :2].squeeze(0), keypoints_info[..., 2].squeeze(0) - return keypoints, scores diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py b/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py deleted file mode 100644 index 439408e6e08acb463b2d5e27c3e4f9193a1d6379..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/merge_datadict.py +++ /dev/null @@ -1,52 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
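Stepping back to the convert_coco_to_openpose helper removed above: it builds a synthetic neck joint as the midpoint of the two shoulders, gives it the smaller of the two shoulder scores, inserts it into the keypoint array, and reorders the body joints into OpenPose order (so the neck lands at index 1). A small usage sketch under assumed shapes (133 COCO-WholeBody keypoints with per-keypoint scores); the inputs are placeholders, not data from this repository:

import numpy as np

kpts = np.random.rand(133, 2).astype(np.float32)    # (x, y) per COCO-WholeBody keypoint
scores = np.random.rand(133).astype(np.float32)     # per-keypoint confidence

op_kpts, op_scores = convert_coco_to_openpose(kpts, scores)
# op_kpts: (134, 2), op_scores: (134,) -- one extra neck keypoint (shoulder midpoint)
# now sits at index 1, matching the openpose134 keypoint ordering defined above.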
- -from typing import Optional - -from cosmos_transfer1.diffusion.datasets.augmentors.control_input import Augmentor -from cosmos_transfer1.utils import log - - -class DataDictMerger(Augmentor): - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Merge the dictionary associated with the input keys into data_dict. Only keys in output_keys are merged. - - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict with dictionary associated with the input keys merged. - """ - for key in self.input_keys: - if key not in data_dict: - log.warning( - f"DataDictMerger dataloader error: missing {key}; data_dict keys: {data_dict.keys()}", - rank0_only=False, - ) - return None - key_dict = data_dict.pop(key) - if key == "depth" and "depth" in self.output_keys: - data_dict["depth"] = key_dict - if key == "keypoint" and "keypoint" in self.output_keys: - data_dict["keypoint"] = key_dict - elif key == "segmentation" and "segmentation" in self.output_keys: - data_dict["segmentation"] = key_dict - for sub_key in key_dict: - if sub_key in self.output_keys and sub_key not in data_dict: - data_dict[sub_key] = key_dict[sub_key] - del key_dict - return data_dict diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/text_transforms_for_video.py b/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/text_transforms_for_video.py deleted file mode 100644 index b979faa51f83fd33ca6e5121c2f3ff9a9997e007..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/augmentors/text_transforms_for_video.py +++ /dev/null @@ -1,147 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import random -from typing import Optional - -import numpy as np -import torch - -from cosmos_transfer1.diffusion.datasets.augmentors.control_input import Augmentor -from cosmos_transfer1.utils import log - - -def pad_and_resize( - arr_np: np.ndarray, ntokens: int, is_mask_all_ones: bool = False -) -> tuple[torch.Tensor, torch.Tensor]: - r"""Function for padding and resizing a numpy array. - Args: - arr (np.ndarray): Input array - ntokens (int): Number of output tokens after padding - is_mask_all_ones (bool): if true, set mask to ones - Returns: - arr_padded (torch.Tensor): Padded output tensor - mask (torch.Tensor): Padding mask - """ - - if isinstance(arr_np, np.ndarray): - arr = torch.from_numpy(arr_np) - elif isinstance(arr_np, torch.Tensor): - arr = arr_np.clone().detach() - else: - raise TypeError("`arr_np` should be a numpy array or torch tensor.") - embed_dim = arr.shape[1] - - arr_padded = torch.zeros(ntokens, embed_dim, device=arr.device, dtype=torch.float32) - - # If the input text is larger than num_text_tokens, clip it. 
- if arr.shape[0] > ntokens: - arr = arr[0:ntokens] - - mask = torch.LongTensor(ntokens).zero_() - if len(arr.shape) > 1: - mask[0 : arr.shape[0]] = 1 - - if len(arr.shape) > 1: - arr_padded[0 : arr.shape[0]] = arr - - if is_mask_all_ones: - mask.fill_(1) - - return arr_padded, mask - - -class TextTransformForVideo(Augmentor): - def __init__(self, input_keys: dict, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Performs text transformation. - - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict with captions and t5 embeddings added - """ - data_source = data_dict["__url__"].meta.source - input_keys_by_source = self.input_keys[data_source] - - if "chunk_index" not in data_dict: - log.warning( - "Chunk_index is not in data_dict, set chunk_index to be 0. This should only happen for sampling." - ) - data_dict["chunk_index"] = 0 # this is for sampling only, whereas decoder is not loaded - try: - windows = data_dict[input_keys_by_source["ai_caption"]]["windows"] - n_windows = len(windows) - chunk_index = data_dict["chunk_index"] - - if chunk_index == n_windows: - # This will only happen when the number of captions does not match number of chunks due to re-transcoding the videos. - log.info( - f"Found {data_dict['orig_num_frames']} in video but captioning is done with videos of {windows[-1]['end_frame']} frames. This mismatch is due to video re-transcoding.", - rank0_only=False, - ) - chunk_index -= 1 - - selected_caption_window = windows[chunk_index] - except Exception as e: - log.warning( - f"TextTransform dataloader error: {data_dict['__url__']}, {data_dict['__key__']}, {data_dict['chunk_index']}\n error {e}", - rank0_only=False, - ) - return None - - try: - if "vila_caption" in selected_caption_window: - caption_type = "vila_caption" - else: - caption_type = random.choices(["long_caption", "short_caption"], weights=[0.95, 0.05], k=1)[0] - data_dict["ai_caption"] = selected_caption_window[caption_type] - except Exception as e: - log.warning( - f"TextTransform dataloader error: {data_dict['__url__']}, {data_dict['__key__']}, {selected_caption_window}\n error {e}", - rank0_only=False, - ) - return None - - if data_dict["ai_caption"] is None: - data_dict["ai_caption"] = "" - del data_dict[input_keys_by_source["ai_caption"]] - - ai_caption_embedding_data = data_dict[input_keys_by_source["ai_caption_embedding"]] - try: - if caption_type in ["vila_caption"]: - t5_embedding = ai_caption_embedding_data[data_dict["chunk_index"]] - else: - t5_embedding = ai_caption_embedding_data[data_dict["chunk_index"]][ - caption_type.replace("_caption", "") - ] # t5_embedding is saved in {"short": array, "long": array} format - except Exception as e: - log.warning( - f"TextTransform dataloader error: {data_dict['__url__']}, {data_dict['__key__']}, {data_dict['chunk_index']}, {len(ai_caption_embedding_data)} \n error {e}", - rank0_only=False, - ) - return None - out_t5, out_t5_mask = pad_and_resize( - t5_embedding, - self.args["t5_tokens"]["num"], - is_mask_all_ones=self.args["is_mask_all_ones"], - ) - data_dict["t5_text_embeddings"] = out_t5 - data_dict["t5_text_mask"] = out_t5_mask - del data_dict[input_keys_by_source["ai_caption_embedding"]] - - return data_dict diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/dataset_utils.py b/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/dataset_utils.py deleted file mode 100644 index 
78177165026d5c9a57ed0e475c63b553853d7f92..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/dataset_utils.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Optional, Union - -import torch -import torchvision.transforms.functional as transforms_F -from PIL import Image - - -def obtain_image_size(data_dict: dict, input_keys: list) -> tuple[int, int]: - r"""Function for obtaining the image size from the data dict. - - Args: - data_dict (dict): Input data dict - input_keys (list): List of input keys - Returns: - width (int): Width of the input image - height (int): Height of the input image - """ - - data1 = data_dict[input_keys[0]] - if isinstance(data1, Image.Image): - width, height = data1.size - elif isinstance(data1, torch.Tensor): - height, width = data1.size()[-2:] - else: - raise ValueError("data to random crop should be PIL Image or tensor") - - return width, height - - -def obtain_augmentation_size(data_dict: dict, augmentor_cfg: dict) -> Union[int, tuple]: - r"""Function for obtaining size of the augmentation. - When dealing with multi-aspect ratio dataloaders, we need to - find the augmentation size from the aspect ratio of the data. - - Args: - data_dict (dict): Input data dict - augmentor_cfg (dict): Augmentor config - Returns: - aug_size (int): Size of augmentation - """ - aspect_ratio = data_dict["aspect_ratio"] - aug_size = augmentor_cfg["size"][aspect_ratio] - return aug_size - - -class Augmentor: - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - r"""Base augmentor class - - Args: - input_keys (list): List of input keys - output_keys (list): List of output keys - args (dict): Arguments associated with the augmentation - """ - self.input_keys = input_keys - self.output_keys = output_keys - self.args = args - - def __call__(self, *args: Any, **kwds: Any) -> Any: - raise ValueError("Augmentor not implemented") - - -class ResizeSmallestSideAspectPreserving(Augmentor): - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Performs aspect-ratio preserving resizing. - Image is resized to the dimension which has the smaller ratio of (size / target_size). - First we compute (w_img / w_target) and (h_img / h_target) and resize the image - to the dimension that has the smaller of these ratios. 
- - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict where images are resized - """ - - if self.output_keys is None: - self.output_keys = self.input_keys - assert self.args is not None, "Please specify args in augmentations" - - img_w, img_h = self.args["img_w"], self.args["img_h"] - - orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) - scaling_ratio = max((img_w / orig_w), (img_h / orig_h)) - target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5)) - - assert ( - target_size[0] >= img_h and target_size[1] >= img_w - ), f"Resize error. orig {(orig_w, orig_h)} desire {(img_w, img_h)} compute {target_size}" - - for inp_key, out_key in zip(self.input_keys, self.output_keys): - data_dict[out_key] = transforms_F.resize( - data_dict[inp_key], - size=target_size, # type: ignore - interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC), - antialias=True, - ) - - if out_key != inp_key: - del data_dict[inp_key] - return data_dict - - -class CenterCrop(Augmentor): - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Performs center crop. - - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict where images are center cropped. - We also save the cropping parameters in the aug_params dict - so that it will be used by other transforms. - """ - assert ( - (self.args is not None) and ("img_w" in self.args) and ("img_h" in self.args) - ), "Please specify size in args" - - img_w, img_h = self.args["img_w"], self.args["img_h"] - - orig_w, orig_h = obtain_image_size(data_dict, self.input_keys) - for key in self.input_keys: - data_dict[key] = transforms_F.center_crop(data_dict[key], [img_h, img_w]) - - # We also add the aug params we use. This will be useful for other transforms - crop_x0 = (orig_w - img_w) // 2 - crop_y0 = (orig_h - img_h) // 2 - cropping_params = { - "resize_w": orig_w, - "resize_h": orig_h, - "crop_x0": crop_x0, - "crop_y0": crop_y0, - "crop_w": img_w, - "crop_h": img_h, - } - - if "aug_params" not in data_dict: - data_dict["aug_params"] = dict() - - data_dict["aug_params"]["cropping"] = cropping_params - data_dict["padding_mask"] = torch.zeros((1, cropping_params["crop_h"], cropping_params["crop_w"])) - return data_dict - - -class Normalize(Augmentor): - def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None: - super().__init__(input_keys, output_keys, args) - - def __call__(self, data_dict: dict) -> dict: - r"""Performs data normalization. - - Args: - data_dict (dict): Input data dict - Returns: - data_dict (dict): Output dict where images are center cropped. 
- """ - assert self.args is not None, "Please specify args" - - mean = self.args["mean"] - std = self.args["std"] - - for key in self.input_keys: - if isinstance(data_dict[key], torch.Tensor): - data_dict[key] = data_dict[key].to(dtype=torch.get_default_dtype()).div(255) - else: - data_dict[key] = transforms_F.to_tensor(data_dict[key]) # division by 255 is applied in to_tensor() - - data_dict[key] = transforms_F.normalize(tensor=data_dict[key], mean=mean, std=std) - return data_dict diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py b/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py deleted file mode 100644 index 0c698dcd663b56bc37cf8f1ba79ebadb68a529c0..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py +++ /dev/null @@ -1,529 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Run this command to interactively debug: -PYTHONPATH=. python cosmos_transfer1/diffusion/datasets/example_transfer_dataset.py -""" - -import os -import pickle -import traceback -import warnings - -import numpy as np -import torch -from decord import VideoReader, cpu -from torch.utils.data import Dataset - -from cosmos_transfer1.diffusion.datasets.augmentor_provider import AUGMENTOR_OPTIONS -from cosmos_transfer1.diffusion.datasets.augmentors.control_input import VIDEO_RES_SIZE_INFO -from cosmos_transfer1.diffusion.inference.inference_utils import detect_aspect_ratio -from cosmos_transfer1.utils.lazy_config import instantiate - -# mappings between control types and corresponding sub-folders names in the data folder -CTRL_TYPE_INFO = { - "keypoint": {"folder": "keypoint", "format": "pickle", "data_dict_key": "keypoint"}, - "depth": {"folder": "depth", "format": "mp4", "data_dict_key": "depth"}, - "lidar": {"folder": "lidar", "format": "mp4", "data_dict_key": "lidar"}, - "hdmap": {"folder": "hdmap", "format": "mp4", "data_dict_key": "hdmap"}, - "seg": {"folder": "seg", "format": "pickle", "data_dict_key": "segmentation"}, - "edge": {"folder": None}, # Canny edge, computed on-the-fly - "vis": {"folder": None}, # Blur, computed on-the-fly - "upscale": {"folder": None}, # Computed on-the-fly -} - - -class ExampleTransferDataset(Dataset): - def __init__(self, dataset_dir, num_frames, resolution, hint_key="control_input_vis", is_train=True): - """Dataset class for loading video-text-to-video generation data with control inputs. - - Args: - dataset_dir (str): Base path to the dataset directory - num_frames (int): Number of consecutive frames to load per sequence - resolution (str): resolution of the target video size - hint_key (str): The hint key for loading the correct control input data modality - is_train (bool): Whether this is for training - - NOTE: in our example dataset we do not have a validation dataset. 
The is_train flag is kept here for customized configuration. - """ - super().__init__() - self.dataset_dir = dataset_dir - self.sequence_length = num_frames - self.is_train = is_train - self.resolution = resolution - assert ( - resolution in VIDEO_RES_SIZE_INFO.keys() - ), "The provided resolution cannot be found in VIDEO_RES_SIZE_INFO." - - # Control input setup with file formats - self.ctrl_type = hint_key.replace("control_input_", "") - self.ctrl_data_pth_config = CTRL_TYPE_INFO[self.ctrl_type] - - # Set up directories - only collect paths - video_dir = os.path.join(self.dataset_dir, "videos") - self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")] - self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl") - print(f"Finish initializing dataset with {len(self.video_paths)} videos in total.") - - # Set up preprocessing and augmentation - augmentor_name = f"video_ctrlnet_augmentor_{hint_key}" - augmentor_cfg = AUGMENTOR_OPTIONS[augmentor_name](resolution=resolution) - self.augmentor = {k: instantiate(v) for k, v in augmentor_cfg.items()} - - def _sample_frames(self, video_path): - """Sample frames from video and get metadata""" - vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) - n_frames = len(vr) - - # Calculate valid start frame range - max_start_idx = n_frames - self.sequence_length - if max_start_idx < 0: # Video is too short - return None, None, None - - # Sample start frame - start_frame = np.random.randint(0, max_start_idx + 1) - frame_ids = list(range(start_frame, start_frame + self.sequence_length)) - - # Load frames - frames = vr.get_batch(frame_ids).asnumpy() - frames = frames.astype(np.uint8) - try: - fps = vr.get_avg_fps() - except Exception: # failed to read FPS - fps = 24 - - return frames, frame_ids, fps - - def _load_control_data(self, sample): - """Load control data for the video clip.""" - data_dict = {} - frame_ids = sample["frame_ids"] - ctrl_path = sample["ctrl_path"] - try: - if self.ctrl_type == "seg": - with open(ctrl_path, "rb") as f: - ctrl_data = pickle.load(f) - # key should match line 982 at cosmos_transfer1/diffusion/datasets/augmentors/control_input.py - data_dict["segmentation"] = ctrl_data - elif self.ctrl_type == "keypoint": - with open(ctrl_path, "rb") as f: - ctrl_data = pickle.load(f) - data_dict["keypoint"] = ctrl_data - elif self.ctrl_type == "depth": - vr = VideoReader(ctrl_path, ctx=cpu(0)) - # Ensure the depth video has the same number of frames - assert len(vr) >= frame_ids[-1] + 1, f"Depth video {ctrl_path} has fewer frames than main video" - - # Load the corresponding frames - depth_frames = vr.get_batch(frame_ids).asnumpy() # [T,H,W,C] - depth_frames = torch.from_numpy(depth_frames).permute(3, 0, 1, 2) # [C,T,H,W], same as rgb video - data_dict["depth"] = { - "video": depth_frames, - "frame_start": frame_ids[0], - "frame_end": frame_ids[-1], - } - elif self.ctrl_type == "lidar": - vr = VideoReader(ctrl_path, ctx=cpu(0)) - # Ensure the lidar depth video has the same number of frames - assert len(vr) >= frame_ids[-1] + 1, f"Lidar video {ctrl_path} has fewer frames than main video" - # Load the corresponding frames - lidar_frames = vr.get_batch(frame_ids).asnumpy() # [T,H,W,C] - lidar_frames = torch.from_numpy(lidar_frames).permute(3, 0, 1, 2) # [C,T,H,W], same as rgb video - data_dict["lidar"] = { - "video": lidar_frames, - "frame_start": frame_ids[0], - "frame_end": frame_ids[-1], - } - elif self.ctrl_type == "hdmap": - vr = VideoReader(ctrl_path, ctx=cpu(0)) - # Ensure the hdmap video has 
the same number of frames - assert len(vr) >= frame_ids[-1] + 1, f"Hdmap video {ctrl_path} has fewer frames than main video" - # Load the corresponding frames - hdmap_frames = vr.get_batch(frame_ids).asnumpy() # [T,H,W,C] - hdmap_frames = torch.from_numpy(hdmap_frames).permute(3, 0, 1, 2) # [C,T,H,W], same as rgb video - data_dict["hdmap"] = { - "video": hdmap_frames, - "frame_start": frame_ids[0], - "frame_end": frame_ids[-1], - } - - except Exception as e: - warnings.warn(f"Failed to load control data from {ctrl_path}: {str(e)}") - return None - - return data_dict - - def __getitem__(self, index): - max_retries = 3 - for _ in range(max_retries): - try: - video_path = self.video_paths[index] - video_name = os.path.basename(video_path).replace(".mp4", "") - - # Sample frames - frames, frame_ids, fps = self._sample_frames(video_path) - if frames is None: # Invalid video or too short - index = np.random.randint(len(self.video_paths)) - continue - - data = dict() - - # Process video frames - video = torch.from_numpy(frames).permute(3, 0, 1, 2) # [T,H,W,C] -> [C,T,H,W] - aspect_ratio = detect_aspect_ratio((video.shape[3], video.shape[2])) # expects (W, H) - - # Basic data - data["video"] = video - data["aspect_ratio"] = aspect_ratio - - # Load T5 embeddings - if self.ctrl_type in ["hdmap", "lidar"]: - # AV data load captions differently - data["video_name"] = { - "video_path": video_path, - "t5_embedding_path": os.path.join(self.t5_dir, f"{video_name}.pkl"), - "start_frame_id": str(frame_ids[0]), - } - with open(data["video_name"]["t5_embedding_path"], "rb") as f: - t5_embedding = pickle.load(f)["pickle"]["ground_truth"]["embeddings"]["t5_xxl"] - # Ensure t5_embedding is a numpy array - if isinstance(t5_embedding, list): - t5_embedding = np.array(t5_embedding[0] if len(t5_embedding) > 0 else t5_embedding) - data["t5_text_embeddings"] = torch.from_numpy(t5_embedding) # .cuda() - data["t5_text_mask"] = torch.ones(512, dtype=torch.int64) # .cuda() - else: - data["video_name"] = { - "video_path": video_path, - "t5_embedding_path": os.path.join(self.t5_dir, f"{video_name}.pickle"), - "start_frame_id": str(frame_ids[0]), - } - with open(data["video_name"]["t5_embedding_path"], "rb") as f: - t5_embedding = pickle.load(f) - # Ensure t5_embedding is a numpy array - if isinstance(t5_embedding, list): - t5_embedding = np.array(t5_embedding[0] if len(t5_embedding) > 0 else t5_embedding) - data["t5_text_embeddings"] = torch.from_numpy(t5_embedding) # .cuda() - data["t5_text_mask"] = torch.ones(512, dtype=torch.int64) # .cuda() - - # Add metadata - data["fps"] = fps - data["frame_start"] = frame_ids[0] - data["frame_end"] = frame_ids[-1] + 1 - data["num_frames"] = self.sequence_length - data["image_size"] = torch.tensor([704, 1280, 704, 1280]) # .cuda() - data["padding_mask"] = torch.zeros(1, 704, 1280) # .cuda() - - if self.ctrl_type: - ctrl_data = self._load_control_data( - { - "ctrl_path": os.path.join( - self.dataset_dir, - self.ctrl_data_pth_config["folder"], - f"{video_name}.{self.ctrl_data_pth_config['format']}", - ) - if self.ctrl_data_pth_config["folder"] is not None - else None, - "frame_ids": frame_ids, - } - ) - if ctrl_data is None: # Control data loading failed - index = np.random.randint(len(self.video_paths)) - continue - data.update(ctrl_data) - - # The ctrl_data above is the 'raw' data loaded (e.g. a loaded segmentation pkl). - # Next, we process it into the control input "video" tensor that the model expects. - # This is done in the augmentor. 
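# Illustrative note, not part of the deleted file: each aug_fn below maps data_dict -> data_dict.
# The augmentors are expected to consume the raw control entry loaded above (e.g. data["segmentation"],
# data["depth"], data["lidar"]) and emit the model-facing control tensor under the hint key
# (e.g. data["control_input_seg"]), shaped like the RGB clip, i.e. [C, T, H, W].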
- for _, aug_fn in self.augmentor.items(): - data = aug_fn(data) - - return data - - except Exception: - warnings.warn( - f"Invalid data encountered: {self.video_paths[index]}. Skipped " - f"(by randomly sampling another sample in the same dataset)." - ) - warnings.warn("FULL TRACEBACK:") - warnings.warn(traceback.format_exc()) - if _ == max_retries - 1: - raise RuntimeError(f"Failed to load data after {max_retries} attempts") - index = np.random.randint(len(self.video_paths)) - return - - def __len__(self): - return len(self.video_paths) - - def __str__(self): - return f"{len(self.video_paths)} samples from {self.dataset_dir}" - - -class AVTransferDataset(ExampleTransferDataset): - def __init__( - self, - dataset_dir, - num_frames, - resolution, - view_keys, - hint_key="control_input_hdmap", - sample_n_views=-1, - caption_view_idx_map=None, - is_train=True, - load_mv_emb=False, - ): - """Dataset class for loading video-text-to-video generation data with control inputs. - - Args: - dataset_dir (str): Base path to the dataset directory - num_frames (int): Number of consecutive frames to load per sequence - resolution (str): resolution of the target video size - hint_key (str): The hint key for loading the correct control input data modality - view_keys (list[str]): list of view names that the dataloader should load - sample_n_views (int): Number of views to sample - caption_view_idx_map (dict): Optional dictionary mapping index in view_keys to index in model.view_embeddings - is_train (bool): Whether this is for training - load_mv_emb (bool): Whether to load t5 embeddings for all views, or only for front view - NOTE: in our example dataset we do not have a validation dataset. The is_train flag is kept here for customized configuration. - """ - super(ExampleTransferDataset, self).__init__() - self.dataset_dir = dataset_dir - self.sequence_length = num_frames - self.is_train = is_train - self.resolution = resolution - self.view_keys = view_keys - self.load_mv_emb = load_mv_emb - assert ( - resolution in VIDEO_RES_SIZE_INFO.keys() - ), "The provided resolution cannot be found in VIDEO_RES_SIZE_INFO." 
- - # Control input setup with file formats - self.ctrl_type = hint_key.replace("control_input_", "") - self.ctrl_data_pth_config = CTRL_TYPE_INFO[self.ctrl_type] - - # Set up directories - only collect paths - video_dir = os.path.join(self.dataset_dir, "videos", "pinhole_front") - self.video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(".mp4")] - self.t5_dir = os.path.join(self.dataset_dir, "t5_xxl") - - cache_dir = os.path.join(self.dataset_dir, "cache") - self.prefix_t5_embeddings = {} - for view_key in view_keys: - with open(os.path.join(cache_dir, f"prefix_{view_key}.pkl"), "rb") as f: - self.prefix_t5_embeddings[view_key] = pickle.load(f) - if caption_view_idx_map is None: - self.caption_view_idx_map = dict([(i, i) for i in range(len(self.view_keys))]) - else: - self.caption_view_idx_map = caption_view_idx_map - self.sample_n_views = sample_n_views - - print(f"Finish initializing dataset with {len(self.video_paths)} videos in total.") - - # Set up preprocessing and augmentation - augmentor_name = f"video_ctrlnet_augmentor_{hint_key}" - augmentor_cfg = AUGMENTOR_OPTIONS[augmentor_name](resolution=resolution) - self.augmentor = {k: instantiate(v) for k, v in augmentor_cfg.items()} - - def _load_video(self, video_path, frame_ids): - vr = VideoReader(video_path, ctx=cpu(0), num_threads=2) - assert (np.array(frame_ids) < len(vr)).all() - assert (np.array(frame_ids) >= 0).all() - vr.seek(0) - frame_data = vr.get_batch(frame_ids).asnumpy() - try: - fps = vr.get_avg_fps() - except Exception: # failed to read FPS - fps = 24 - return frame_data, fps - - def __getitem__(self, index): - max_retries = 3 - for _ in range(max_retries): - try: - video_path = self.video_paths[index] - video_name = os.path.basename(video_path).replace(".mp4", "") - - data = dict() - ctrl_videos = [] - videos = [] - t5_embeddings = [] - t5_masks = [] - view_indices = [i for i in range(len(self.view_keys))] - view_indices_conditioning = [] - if self.sample_n_views > 1: - sampled_idx = np.random.choice( - np.arange(1, len(view_indices)), - size=min(self.sample_n_views - 1, len(view_indices) - 1), - replace=False, - ) - sampled_idx = np.concatenate( - [ - [ - 0, - ], - sampled_idx, - ] - ) - sampled_idx.sort() - view_indices = sampled_idx.tolist() - - frame_ids = None - fps = None - for view_index in view_indices: - view_key = self.view_keys[view_index] - if frame_ids is None: - frames, frame_ids, fps = self._sample_frames(video_path) - if frames is None: # Invalid video or too short - raise Exception(f"Failed to load frames {video_path}") - - else: - frames, fps = self._load_video( - os.path.join(self.dataset_dir, "videos", view_key, os.path.basename(video_path)), frame_ids - ) - # Process video frames - video = torch.from_numpy(frames) - - video = video.permute(3, 0, 1, 2) # Rearrange from [T, C, H, W] to [C, T, H, W] - aspect_ratio = detect_aspect_ratio((video.shape[3], video.shape[2])) # expects (W, H) - videos.append(video) - - if video_name[-2] == "_" and video_name[-1].isdigit(): - video_name_emb = video_name[:-2] - else: - video_name_emb = video_name - - if self.load_mv_emb or view_key == "pinhole_front": - t5_embedding_path = os.path.join(self.dataset_dir, "t5_xxl", view_key, f"{video_name_emb}.pkl") - with open(t5_embedding_path, "rb") as f: - t5_embedding = pickle.load(f)[0] - if self.load_mv_emb: - t5_embedding = np.concatenate([self.prefix_t5_embeddings[view_key], t5_embedding], axis=0) - else: - # use camera prompt - t5_embedding = self.prefix_t5_embeddings[view_key] - - 
t5_embedding = torch.from_numpy(t5_embedding) - t5_mask = torch.ones(t5_embedding.shape[0], dtype=torch.int64) - if t5_embedding.shape[0] < 512: - t5_embedding = torch.cat([t5_embedding, torch.zeros(512 - t5_embedding.shape[0], 1024)], dim=0) - t5_mask = torch.cat([t5_mask, torch.zeros(512 - t5_mask.shape[0])], dim=0) - else: - t5_embedding = t5_embedding[:512] - t5_mask = t5_mask[:512] - t5_embeddings.append(t5_embedding) - t5_masks.append(t5_mask) - caption_viewid = self.caption_view_idx_map[view_index] - view_indices_conditioning.append(torch.ones(video.shape[1]) * caption_viewid) - - if self.ctrl_type: - v_ctrl_data = self._load_control_data( - { - "ctrl_path": os.path.join( - self.dataset_dir, - self.ctrl_data_pth_config["folder"], - view_key, - f"{video_name}.{self.ctrl_data_pth_config['format']}", - ) - if self.ctrl_data_pth_config["folder"] is not None - else None, - "frame_ids": frame_ids, - } - ) - if v_ctrl_data is None: # Control data loading failed - raise Exception("Failed to load v_ctrl_data") - ctrl_videos.append(v_ctrl_data[self.ctrl_type]["video"]) - - video = torch.cat(videos, dim=1) - ctrl_videos = torch.cat(ctrl_videos, dim=1) - t5_embedding = torch.cat(t5_embeddings, dim=0) - view_indices_conditioning = torch.cat(view_indices_conditioning, dim=0) - - # Basic data - data["video"] = video - data["video_name"] = video_name - data["aspect_ratio"] = aspect_ratio - data["t5_text_embeddings"] = t5_embedding - data["t5_text_mask"] = torch.cat(t5_masks) - data["view_indices"] = view_indices_conditioning.contiguous() - data["frame_repeat"] = torch.zeros(len(view_indices)) - # Add metadata - data["fps"] = fps - data["frame_start"] = frame_ids[0] - data["frame_end"] = frame_ids[-1] + 1 - data["num_frames"] = self.sequence_length - data["image_size"] = torch.tensor([704, 1280, 704, 1280]) - data["padding_mask"] = torch.zeros(1, 704, 1280) - data[self.ctrl_type] = dict() - data[self.ctrl_type]["video"] = ctrl_videos - - # The ctrl_data above is the 'raw' data loaded (e.g. a loaded lidar pkl). - # Next, we process it into the control input "video" tensor that the model expects. - # This is done in the augmentor. - for _, aug_fn in self.augmentor.items(): - data = aug_fn(data) - - return data - - except Exception: - warnings.warn( - f"Invalid data encountered: {self.video_paths[index]}. Skipped " - f"(by randomly sampling another sample in the same dataset)." - ) - warnings.warn("FULL TRACEBACK:") - warnings.warn(traceback.format_exc()) - if _ == max_retries - 1: - raise RuntimeError(f"Failed to load data after {max_retries} attempts") - index = np.random.randint(len(self.video_paths)) - return - - -if __name__ == "__main__": - """ - Sanity check for the dataset. 
- """ - control_input_key = "control_input_lidar" - visualize_control_input = True - - dataset = AVTransferDataset( - dataset_dir="datasets/waymo_transfer1", - view_keys=["pinhole_front"], - hint_key=control_input_key, - num_frames=121, - resolution="720", - is_train=True, - ) - print("finished init dataset") - indices = [0, 12, 100, -1] - for idx in indices: - data = dataset[idx] - print( - ( - f"{idx=} " - f"{data['frame_start']=}\n" - f"{data['frame_end']=}\n" - f"{data['video'].sum()=}\n" - f"{data['video'].shape=}\n" - f"{data[control_input_key].shape=}\n" # should match the video shape - f"{data['video_name']=}\n" - f"{data['t5_text_embeddings'].shape=}\n" - "---" - ) - ) - if visualize_control_input: - import imageio - - control_input_tensor = data[control_input_key].permute(1, 2, 3, 0).cpu().numpy() - video_name = f"{control_input_key}.mp4" - imageio.mimsave(video_name, control_input_tensor, fps=24) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/functional/batch_ops.py b/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/functional/batch_ops.py deleted file mode 100644 index 09440b34a95b1708d2154376f2a0202a533cb3b2..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/functional/batch_ops.py +++ /dev/null @@ -1,46 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Functions for performing operations with broadcasting to the right axis -# -# Example -# input1: tensor of size (N1, N2) -# input2: tensor of size (N1, N2, N3, N4) -# batch_mul(input1, input2) = input1[:, :, None, None] * input2 -# -# If the common dimensions don't match, we raise an assertion error. - -from torch import Tensor - - -def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: - ndims1 = x.ndim - ndims2 = y.ndim - - common_ndims = min(ndims1, ndims2) - for axis in range(common_ndims): - assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis) - - if ndims1 < ndims2: - x = x.reshape(x.shape + (1,) * (ndims2 - ndims1)) - elif ndims2 < ndims1: - y = y.reshape(y.shape + (1,) * (ndims1 - ndims2)) - - return x, y - - -def batch_mul(x: Tensor, y: Tensor) -> Tensor: - x, y = common_broadcast(x, y) - return x * y diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/functional/multi_step.py b/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/functional/multi_step.py deleted file mode 100644 index e33bdb37e44a158870420bf82b0d8dd017ad6d2d..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/functional/multi_step.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
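Returning to the batch_ops helpers removed above: batch_mul right-pads the lower-rank tensor with singleton dimensions so the product broadcasts over the trailing axes, which is how per-sample scalars (noise levels, time steps) scale whole latent tensors. A tiny illustrative check that exercises the batch_mul helper above; it is not part of either deleted file:

import torch

sigma = torch.tensor([0.5, 2.0])   # per-sample scalars, shape (2,)
x = torch.ones(2, 3, 4, 4)         # batch of tensors, shape (2, 3, 4, 4)
y = batch_mul(sigma, x)            # sigma is reshaped to (2, 1, 1, 1) before multiplying
assert y.shape == (2, 3, 4, 4) and torch.all(y[1] == 2.0)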
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Impl of multistep methods to solve the ODE in the diffusion model. -""" - -from typing import Callable, List, Tuple - -import torch - -from cosmos_transfer1.diffusion.diffusion.functional.runge_kutta import reg_x0_euler_step, res_x0_rk2_step - - -def order2_fn( - x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_s: torch.Tensor, x0_preds: torch.Tensor -) -> Tuple[torch.Tensor, List[torch.Tensor]]: - """ - impl the second order multistep method in https://arxiv.org/pdf/2308.02157 - Adams Bashforth approach! - """ - if x0_preds: - x0_s1, s1 = x0_preds[0] - x_t = res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1) - else: - x_t = reg_x0_euler_step(x_s, s, t, x0_s)[0] - return x_t, [(x0_s, s)] - - -# key: method name, value: method function -# key: order + algorithm name -MULTISTEP_FNs = { - "2ab": order2_fn, -} - - -def get_multi_step_fn(name: str) -> Callable: - if name in MULTISTEP_FNs: - return MULTISTEP_FNs[name] - methods = "\n\t".join(MULTISTEP_FNs.keys()) - raise RuntimeError("Only support multistep method\n" + methods) - - -def is_multi_step_fn_supported(name: str) -> bool: - """ - Check if the multistep method is supported. - """ - return name in MULTISTEP_FNs diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/functional/runge_kutta.py b/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/functional/runge_kutta.py deleted file mode 100644 index 406b390b682942f2d24b46c0e9f8720e6d846b8b..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/functional/runge_kutta.py +++ /dev/null @@ -1,333 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, Tuple - -import torch - -from cosmos_transfer1.diffusion.diffusion.functional.batch_ops import batch_mul - - -def phi1(t: torch.Tensor) -> torch.Tensor: - """ - Compute the first order phi function: (exp(t) - 1) / t. - - Args: - t: Input tensor. - - Returns: - Tensor: Result of phi1 function. - """ - input_dtype = t.dtype - t = t.to(dtype=torch.float64) - return (torch.expm1(t) / t).to(dtype=input_dtype) - - -def phi2(t: torch.Tensor) -> torch.Tensor: - """ - Compute the second order phi function: (phi1(t) - 1) / t. - - Args: - t: Input tensor. - - Returns: - Tensor: Result of phi2 function. 
- """ - input_dtype = t.dtype - t = t.to(dtype=torch.float64) - return ((phi1(t) - 1.0) / t).to(dtype=input_dtype) - - -def res_x0_rk2_step( - x_s: torch.Tensor, - t: torch.Tensor, - s: torch.Tensor, - x0_s: torch.Tensor, - s1: torch.Tensor, - x0_s1: torch.Tensor, -) -> torch.Tensor: - """ - Perform a residual-based 2nd order Runge-Kutta step. - - Args: - x_s: Current state tensor. - t: Target time tensor. - s: Current time tensor. - x0_s: Prediction at current time. - s1: Intermediate time tensor. - x0_s1: Prediction at intermediate time. - - Returns: - Tensor: Updated state tensor. - - Raises: - AssertionError: If step size is too small. - """ - s = -torch.log(s) - t = -torch.log(t) - m = -torch.log(s1) - - dt = t - s - assert not torch.any(torch.isclose(dt, torch.zeros_like(dt), atol=1e-6)), "Step size is too small" - assert not torch.any(torch.isclose(m - s, torch.zeros_like(dt), atol=1e-6)), "Step size is too small" - - c2 = (m - s) / dt - phi1_val, phi2_val = phi1(-dt), phi2(-dt) - - # Handle edge case where t = s = m - b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0) - b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0) - - return batch_mul(torch.exp(-dt), x_s) + batch_mul(dt, batch_mul(b1, x0_s) + batch_mul(b2, x0_s1)) - - -def reg_x0_euler_step( - x_s: torch.Tensor, - s: torch.Tensor, - t: torch.Tensor, - x0_s: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Perform a regularized Euler step based on x0 prediction. - - Args: - x_s: Current state tensor. - s: Current time tensor. - t: Target time tensor. - x0_s: Prediction at current time. - - Returns: - Tuple[Tensor, Tensor]: Updated state tensor and current prediction. - """ - coef_x0 = (s - t) / s - coef_xs = t / s - return batch_mul(coef_x0, x0_s) + batch_mul(coef_xs, x_s), x0_s - - -def reg_eps_euler_step( - x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, eps_s: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Perform a regularized Euler step based on epsilon prediction. - - Args: - x_s: Current state tensor. - s: Current time tensor. - t: Target time tensor. - eps_s: Epsilon prediction at current time. - - Returns: - Tuple[Tensor, Tensor]: Updated state tensor and current x0 prediction. - """ - return x_s + batch_mul(eps_s, t - s), x_s + batch_mul(eps_s, 0 - s) - - -def rk1_euler( - x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Perform a first-order Runge-Kutta (Euler) step. - - Recommended for diffusion models with guidance or model undertrained - Usually more stable at the cost of a bit slower convergence. - - Args: - x_s: Current state tensor. - s: Current time tensor. - t: Target time tensor. - x0_fn: Function to compute x0 prediction. - - Returns: - Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction. - """ - x0_s = x0_fn(x_s, s) - return reg_x0_euler_step(x_s, s, t, x0_s) - - -def rk2_mid_stable( - x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Perform a stable second-order Runge-Kutta (midpoint) step. - - Args: - x_s: Current state tensor. - s: Current time tensor. - t: Target time tensor. - x0_fn: Function to compute x0 prediction. - - Returns: - Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction. 
- """ - s1 = torch.sqrt(s * t) - x_s1, _ = rk1_euler(x_s, s, s1, x0_fn) - - x0_s1 = x0_fn(x_s1, s1) - return reg_x0_euler_step(x_s, s, t, x0_s1) - - -def rk2_mid(x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Perform a second-order Runge-Kutta (midpoint) step. - - Args: - x_s: Current state tensor. - s: Current time tensor. - t: Target time tensor. - x0_fn: Function to compute x0 prediction. - - Returns: - Tuple[Tensor, Tensor]: Updated state tensor and x0 prediction. - """ - s1 = torch.sqrt(s * t) - x_s1, x0_s = rk1_euler(x_s, s, s1, x0_fn) - - x0_s1 = x0_fn(x_s1, s1) - - return res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1), x0_s1 - - -def rk_2heun_naive( - x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Perform a naive second-order Runge-Kutta (Heun's method) step. - Impl based on rho-rk-deis solvers, https://github.com/qsh-zh/deis - Recommended for diffusion models without guidance and relative large NFE - - Args: - x_s: Current state tensor. - s: Current time tensor. - t: Target time tensor. - x0_fn: Function to compute x0 prediction. - - Returns: - Tuple[Tensor, Tensor]: Updated state tensor and current state. - """ - x_t, x0_s = rk1_euler(x_s, s, t, x0_fn) - eps_s = batch_mul(1.0 / s, x_t - x0_s) - x0_t = x0_fn(x_t, t) - eps_t = batch_mul(1.0 / t, x_t - x0_t) - - avg_eps = (eps_s + eps_t) / 2 - - return reg_eps_euler_step(x_s, s, t, avg_eps) - - -def rk_2heun_edm( - x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Perform a naive second-order Runge-Kutta (Heun's method) step. - Impl based no EDM second order Heun method - - Args: - x_s: Current state tensor. - s: Current time tensor. - t: Target time tensor. - x0_fn: Function to compute x0 prediction. - - Returns: - Tuple[Tensor, Tensor]: Updated state tensor and current state. - """ - x_t, x0_s = rk1_euler(x_s, s, t, x0_fn) - x0_t = x0_fn(x_t, t) - - avg_x0 = (x0_s + x0_t) / 2 - - return reg_x0_euler_step(x_s, s, t, avg_x0) - - -def rk_3kutta_naive( - x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_fn: Callable -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Perform a naive third-order Runge-Kutta step. - Impl based on rho-rk-deis solvers, https://github.com/qsh-zh/deis - Recommended for diffusion models without guidance and relative large NFE - - Args: - x_s: Current state tensor. - s: Current time tensor. - t: Target time tensor. - x0_fn: Function to compute x0 prediction. - - Returns: - Tuple[Tensor, Tensor]: Updated state tensor and current state. - """ - c2, c3 = 0.5, 1.0 - a31, a32 = -1.0, 2.0 - b1, b2, b3 = 1.0 / 6, 4.0 / 6, 1.0 / 6 - - delta = t - s - - s1 = c2 * delta + s - s2 = c3 * delta + s - x_s1, x0_s = rk1_euler(x_s, s, s1, x0_fn) - eps_s = batch_mul(1.0 / s, x_s - x0_s) - x0_s1 = x0_fn(x_s1, s1) - eps_s1 = batch_mul(1.0 / s1, x_s1 - x0_s1) - - _eps = a31 * eps_s + a32 * eps_s1 - x_s2, _ = reg_eps_euler_step(x_s, s, s2, _eps) - - x0_s2 = x0_fn(x_s2, s2) - eps_s2 = batch_mul(1.0 / s2, x_s2 - x0_s2) - - avg_eps = b1 * eps_s + b2 * eps_s1 + b3 * eps_s2 - return reg_eps_euler_step(x_s, s, t, avg_eps) - - -# key : order + name -RK_FNs = { - "1euler": rk1_euler, - "2mid": rk2_mid, - "2mid_stable": rk2_mid_stable, - "2heun_edm": rk_2heun_edm, - "2heun_naive": rk_2heun_naive, - "3kutta_naive": rk_3kutta_naive, -} - - -def get_runge_kutta_fn(name: str) -> Callable: - """ - Get the specified Runge-Kutta function. 
- - Args: - name: Name of the Runge-Kutta method. - - Returns: - Callable: The specified Runge-Kutta function. - - Raises: - RuntimeError: If the specified method is not supported. - """ - if name in RK_FNs: - return RK_FNs[name] - methods = "\n\t".join(RK_FNs.keys()) - raise RuntimeError(f"Only support the following Runge-Kutta methods:\n\t{methods}") - - -def is_runge_kutta_fn_supported(name: str) -> bool: - """ - Check if the specified Runge-Kutta function is supported. - - Args: - name: Name of the Runge-Kutta method. - - Returns: - bool: True if the method is supported, False otherwise. - """ - return name in RK_FNs diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/modules/denoiser_scaling.py b/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/modules/denoiser_scaling.py deleted file mode 100644 index f4fb3df0f38d52de317177c248e22707a899beb4..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/modules/denoiser_scaling.py +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Tuple - -import torch - - -class EDMScaling: - def __init__(self, sigma_data: float = 0.5): - self.sigma_data = sigma_data - - def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) - c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 - c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 - c_noise = 0.25 * sigma.log() - return c_skip, c_out, c_in, c_noise diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py b/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py deleted file mode 100644 index 4e0d70fadbbe7e525deea72acf58d72f8c8be7dd..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py +++ /dev/null @@ -1,283 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -A general framework for various sampling algorithm from a diffusion model. 
-Impl based on -* Refined Exponential Solver (RES) in https://arxiv.org/pdf/2308.02157 -* also clude other impl, DDIM, DEIS, DPM-Solver, EDM sampler. -Most of sampling algorihtm, Runge-Kutta, Multi-step, etc, can be impl in this framework by \ - adding new step function in get_runge_kutta_fn or get_multi_step_fn. -""" - -import math -from typing import Any, Callable, List, Literal, Optional, Tuple, Union - -import attrs -import torch - -from cosmos_transfer1.diffusion.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported -from cosmos_transfer1.diffusion.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported -from cosmos_transfer1.utils.ddp_config import make_freezable - -COMMON_SOLVER_OPTIONS = Literal["2ab", "2mid", "1euler"] - - -@make_freezable -@attrs.define(slots=False) -class SolverConfig: - is_multi: bool = False - rk: str = "2mid" - multistep: str = "2ab" - # following parameters control stochasticity, see EDM paper - # BY default, we use deterministic with no stochasticity - s_churn: float = 0.0 - s_t_max: float = float("inf") - s_t_min: float = 0.05 - s_noise: float = 1.0 - - -@make_freezable -@attrs.define(slots=False) -class SolverTimestampConfig: - nfe: int = 50 - t_min: float = 0.002 - t_max: float = 80.0 - order: float = 7.0 - is_forward: bool = False # whether generate forward or backward timestamps - - -@make_freezable -@attrs.define(slots=False) -class SamplerConfig: - solver: SolverConfig = attrs.field(factory=SolverConfig) - timestamps: SolverTimestampConfig = attrs.field(factory=SolverTimestampConfig) - sample_clean: bool = True # whether run one last step to generate clean image - - -def get_rev_ts( - t_min: float, t_max: float, num_steps: int, ts_order: Union[int, float], is_forward: bool = False -) -> torch.Tensor: - """ - Generate a sequence of reverse time steps. - - Args: - t_min (float): The minimum time value. - t_max (float): The maximum time value. - num_steps (int): The number of time steps to generate. - ts_order (Union[int, float]): The order of the time step progression. - is_forward (bool, optional): If True, returns the sequence in forward order. Defaults to False. - - Returns: - torch.Tensor: A tensor containing the generated time steps in reverse or forward order. - - Raises: - ValueError: If `t_min` is not less than `t_max`. - TypeError: If `ts_order` is not an integer or float. 
- """ - if t_min >= t_max: - raise ValueError("t_min must be less than t_max") - - if not isinstance(ts_order, (int, float)): - raise TypeError("ts_order must be an integer or float") - - step_indices = torch.arange(num_steps + 1, dtype=torch.float64) - time_steps = ( - t_max ** (1 / ts_order) + step_indices / num_steps * (t_min ** (1 / ts_order) - t_max ** (1 / ts_order)) - ) ** ts_order - - if is_forward: - return time_steps.flip(dims=(0,)) - - return time_steps - - -class Sampler(torch.nn.Module): - def __init__(self, cfg: Optional[SamplerConfig] = None): - super().__init__() - if cfg is None: - cfg = SamplerConfig() - self.cfg = cfg - - @torch.no_grad() - def forward( - self, - x0_fn: Callable, - x_sigma_max: torch.Tensor, - num_steps: int = 35, - sigma_min: float = 0.002, - sigma_max: float = 70, - rho: float = 7, - S_churn: float = 0, - S_min: float = 0, - S_max: float = float("inf"), - S_noise: float = 1, - solver_option: str = "2ab", - ) -> torch.Tensor: - in_dtype = x_sigma_max.dtype - - def float64_x0_fn(x_B_StateShape: torch.Tensor, t_B: torch.Tensor) -> torch.Tensor: - return x0_fn(x_B_StateShape.to(in_dtype), t_B.to(in_dtype)).to(torch.float64) - - is_multistep = is_multi_step_fn_supported(solver_option) - is_rk = is_runge_kutta_fn_supported(solver_option) - assert is_multistep or is_rk, f"Only support multistep or Runge-Kutta method, got {solver_option}" - - solver_cfg = SolverConfig( - s_churn=S_churn, - s_t_max=S_max, - s_t_min=S_min, - s_noise=S_noise, - is_multi=is_multistep, - rk=solver_option, - multistep=solver_option, - ) - timestamps_cfg = SolverTimestampConfig(nfe=num_steps, t_min=sigma_min, t_max=sigma_max, order=rho) - sampler_cfg = SamplerConfig(solver=solver_cfg, timestamps=timestamps_cfg, sample_clean=True) - - return self._forward_impl(float64_x0_fn, x_sigma_max, sampler_cfg).to(in_dtype) - - @torch.no_grad() - def _forward_impl( - self, - denoiser_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], - noisy_input_B_StateShape: torch.Tensor, - sampler_cfg: Optional[SamplerConfig] = None, - callback_fns: Optional[List[Callable]] = None, - ) -> torch.Tensor: - """ - Internal implementation of the forward pass. - - Args: - denoiser_fn: Function to denoise the input. - noisy_input_B_StateShape: Input tensor with noise. - sampler_cfg: Configuration for the sampler. - callback_fns: List of callback functions to be called during sampling. - - Returns: - torch.Tensor: Denoised output tensor. - """ - sampler_cfg = self.cfg if sampler_cfg is None else sampler_cfg - solver_order = 1 if sampler_cfg.solver.is_multi else int(sampler_cfg.solver.rk[0]) - num_timestamps = sampler_cfg.timestamps.nfe // solver_order - - sigmas_L = get_rev_ts( - sampler_cfg.timestamps.t_min, sampler_cfg.timestamps.t_max, num_timestamps, sampler_cfg.timestamps.order - ).to(noisy_input_B_StateShape.device) - - denoised_output = differential_equation_solver( - denoiser_fn, sigmas_L, sampler_cfg.solver, callback_fns=callback_fns - )(noisy_input_B_StateShape) - - if sampler_cfg.sample_clean: - # Override denoised_output with fully denoised version - ones = torch.ones(denoised_output.size(0), device=denoised_output.device, dtype=denoised_output.dtype) - denoised_output = denoiser_fn(denoised_output, sigmas_L[-1] * ones) - - return denoised_output - - -def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any) -> Any: - """ - Implements a for loop with a function. - - Args: - lower: Lower bound of the loop (inclusive). - upper: Upper bound of the loop (exclusive). 
- body_fun: Function to be applied in each iteration. - init_val: Initial value for the loop. - - Returns: - The final result after all iterations. - """ - val = init_val - for i in range(lower, upper): - val = body_fun(i, val) - return val - - -def differential_equation_solver( - x0_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], - sigmas_L: torch.Tensor, - solver_cfg: SolverConfig, - callback_fns: Optional[List[Callable]] = None, -) -> Callable[[torch.Tensor], torch.Tensor]: - """ - Creates a differential equation solver function. - - Args: - x0_fn: Function to compute x0 prediction. - sigmas_L: Tensor of sigma values with shape [L,]. - solver_cfg: Configuration for the solver. - callback_fns: Optional list of callback functions. - - Returns: - A function that solves the differential equation. - """ - num_step = len(sigmas_L) - 1 - - if solver_cfg.is_multi: - update_step_fn = get_multi_step_fn(solver_cfg.multistep) - else: - update_step_fn = get_runge_kutta_fn(solver_cfg.rk) - - eta = min(solver_cfg.s_churn / (num_step + 1), math.sqrt(1.2) - 1) - - def sample_fn(input_xT_B_StateShape: torch.Tensor) -> torch.Tensor: - """ - Samples from the differential equation. - - Args: - input_xT_B_StateShape: Input tensor with shape [B, StateShape]. - - Returns: - Output tensor with shape [B, StateShape]. - """ - ones_B = torch.ones(input_xT_B_StateShape.size(0), device=input_xT_B_StateShape.device, dtype=torch.float64) - - def step_fn( - i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]] - ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]: - input_x_B_StateShape, x0_preds = state - sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1] - - # algorithm 2: line 4-6 - if solver_cfg.s_t_min < sigma_cur_0 < solver_cfg.s_t_max: - hat_sigma_cur_0 = sigma_cur_0 + eta * sigma_cur_0 - input_x_B_StateShape = input_x_B_StateShape + ( - hat_sigma_cur_0**2 - sigma_cur_0**2 - ).sqrt() * solver_cfg.s_noise * torch.randn_like(input_x_B_StateShape) - sigma_cur_0 = hat_sigma_cur_0 - - if solver_cfg.is_multi: - x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B) - output_x_B_StateShape, x0_preds = update_step_fn( - input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_pred_B_StateShape, x0_preds - ) - else: - output_x_B_StateShape, x0_preds = update_step_fn( - input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_fn - ) - - if callback_fns: - for callback_fn in callback_fns: - callback_fn(**locals()) - - return output_x_B_StateShape, x0_preds - - x_at_eps, _ = fori_loop(0, num_step, step_fn, [input_xT_B_StateShape, None]) - return x_at_eps - - return sample_fn diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/types.py b/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/types.py deleted file mode 100644 index a209db0eba28a8d8bcb527bfbaca6f5e361ace14..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/diffusion/types.py +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Optional - -import torch - - -@dataclass -class DenoisePrediction: - x0: torch.Tensor # clean data prediction - eps: Optional[torch.Tensor] = None # noise prediction - logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used a confidence / uncertainty diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/functional/batch_ops.py b/cosmos-transfer1/cosmos_transfer1/diffusion/functional/batch_ops.py deleted file mode 100644 index a72b24097f7cc9e7e6a8b324919455131bf84d47..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/functional/batch_ops.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Functions for performing operations with broadcasting to the right axis -# -# Example -# input1: tensor of size (N1, N2) -# input2: tensor of size (N1, N2, N3, N4) -# batch_mul(input1, input2) = input1[:, :, None, None] * input2 -# -# If the common dimensions don't match, we raise an assertion error. - -from torch import Tensor - - -def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: - ndims1 = x.ndim - ndims2 = y.ndim - - common_ndims = min(ndims1, ndims2) - for axis in range(common_ndims): - assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis) - - if ndims1 < ndims2: - x = x.reshape(x.shape + (1,) * (ndims2 - ndims1)) - elif ndims2 < ndims1: - y = y.reshape(y.shape + (1,) * (ndims1 - ndims2)) - - return x, y - - -def batch_add(x: Tensor, y: Tensor) -> Tensor: - x, y = common_broadcast(x, y) - return x + y - - -def batch_mul(x: Tensor, y: Tensor) -> Tensor: - x, y = common_broadcast(x, y) - return x * y - - -def batch_sub(x: Tensor, y: Tensor) -> Tensor: - x, y = common_broadcast(x, y) - return x - y - - -def batch_div(x: Tensor, y: Tensor) -> Tensor: - x, y = common_broadcast(x, y) - return x / y diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/inference/inference_utils.py b/cosmos-transfer1/cosmos_transfer1/diffusion/inference/inference_utils.py deleted file mode 100644 index ebe8eee2820bac5668d9f53448e2e52d15de4a7c..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/inference/inference_utils.py +++ /dev/null @@ -1,1285 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import importlib -import json -import os -from contextlib import contextmanager -from typing import Any, Dict, List, NamedTuple, Optional, Tuple - -import cv2 -import einops -import imageio -import numpy as np -import torch -import torchvision.transforms.functional as transforms_F -from einops import rearrange - -from cosmos_transfer1.auxiliary.guardrail.common.io_utils import save_video -from cosmos_transfer1.checkpoints import ( - DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH, - HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH, - VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, -) -from cosmos_transfer1.diffusion.config.transfer.augmentors import BilateralOnlyBlurAugmentorConfig -from cosmos_transfer1.diffusion.datasets.augmentors.control_input import get_augmentor_for_eval -from cosmos_transfer1.diffusion.model.model_t2w import DiffusionT2WModel -from cosmos_transfer1.diffusion.model.model_v2w import DiffusionV2WModel -from cosmos_transfer1.diffusion.model.model_v2w_multiview import DiffusionV2WMultiviewModel -from cosmos_transfer1.utils import log -from cosmos_transfer1.utils.config_helper import get_config_module, override -from cosmos_transfer1.utils.io import load_from_fileobj - -TORCH_VERSION: Tuple[int, ...] 
= tuple(int(x) for x in torch.__version__.split(".")[:2]) -if TORCH_VERSION >= (1, 11): - from torch.ao import quantization - from torch.ao.quantization import FakeQuantizeBase, ObserverBase -elif ( - TORCH_VERSION >= (1, 8) - and hasattr(torch.quantization, "FakeQuantizeBase") - and hasattr(torch.quantization, "ObserverBase") -): - from torch import quantization - from torch.quantization import FakeQuantizeBase, ObserverBase - -DEFAULT_AUGMENT_SIGMA = 0.001 -NUM_MAX_FRAMES = 5000 -VIDEO_RES_SIZE_INFO = { - "1,1": (960, 960), - "4,3": (960, 704), - "3,4": (704, 960), - "16,9": (1280, 704), - "9,16": (704, 1280), -} - -# Default model names for each control type -default_model_names = { - "vis": VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - "seg": SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - "edge": EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - "depth": DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - "keypoint": KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - "upscale": UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH, - "hdmap": HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - "lidar": LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, -} - -default_distilled_model_names = { - "edge": EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH, -} - - -class _IncompatibleKeys( - NamedTuple( - "IncompatibleKeys", - [ - ("missing_keys", List[str]), - ("unexpected_keys", List[str]), - ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]), - ], - ) -): - pass - - -def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys: - """Load a model checkpoint with non-strict matching, handling shape mismatches. - - Args: - model (torch.nn.Module): Model to load weights into - checkpoint_state_dict (dict): State dict from checkpoint - - Returns: - _IncompatibleKeys: Named tuple containing: - - missing_keys: Keys present in model but missing from checkpoint - - unexpected_keys: Keys present in checkpoint but not in model - - incorrect_shapes: Keys with mismatched tensor shapes - - The function handles special cases like: - - Uninitialized parameters - - Quantization observers - - TransformerEngine FP8 states - """ - # workaround https://github.com/pytorch/pytorch/issues/24139 - model_state_dict = model.state_dict() - incorrect_shapes = [] - for k in list(checkpoint_state_dict.keys()): - if k in model_state_dict: - if "_extra_state" in k: # Key introduced by TransformerEngine for FP8 - log.debug(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.") - continue - model_param = model_state_dict[k] - # Allow mismatch for uninitialized parameters - if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter): - continue - if not isinstance(model_param, torch.Tensor): - raise ValueError( - f"Find non-tensor parameter {k} in the model. type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not." - ) - - shape_model = tuple(model_param.shape) - shape_checkpoint = tuple(checkpoint_state_dict[k].shape) - if shape_model != shape_checkpoint: - has_observer_base_classes = ( - TORCH_VERSION >= (1, 8) - and hasattr(quantization, "ObserverBase") - and hasattr(quantization, "FakeQuantizeBase") - ) - if has_observer_base_classes: - # Handle the special case of quantization per channel observers, - # where buffer shape mismatches are expected. 
- def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module: - # foo.bar.param_or_buffer_name -> [foo, bar] - key_parts = key.split(".")[:-1] - cur_module = model - for key_part in key_parts: - cur_module = getattr(cur_module, key_part) - return cur_module - - cls_to_skip = ( - ObserverBase, - FakeQuantizeBase, - ) - target_module = _get_module_for_key(model, k) - if isinstance(target_module, cls_to_skip): - # Do not remove modules with expected shape mismatches - # them from the state_dict loading. They have special logic - # in _load_from_state_dict to handle the mismatches. - continue - - incorrect_shapes.append((k, shape_checkpoint, shape_model)) - checkpoint_state_dict.pop(k) - incompatible = model.load_state_dict(checkpoint_state_dict, strict=False) - # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling - missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k] - unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k] - return _IncompatibleKeys( - missing_keys=missing_keys, - unexpected_keys=unexpected_keys, - incorrect_shapes=incorrect_shapes, - ) - - -@contextmanager -def skip_init_linear(): - # skip init of nn.Linear - orig_reset_parameters = torch.nn.Linear.reset_parameters - torch.nn.Linear.reset_parameters = lambda x: x - xavier_uniform_ = torch.nn.init.xavier_uniform_ - torch.nn.init.xavier_uniform_ = lambda x: x - yield - torch.nn.Linear.reset_parameters = orig_reset_parameters - torch.nn.init.xavier_uniform_ = xavier_uniform_ - - -def load_model_by_config( - config_job_name, - config_file="projects/cosmos_video/config/config.py", - model_class=DiffusionT2WModel, - base_checkpoint_dir="", -): - config_module = get_config_module(config_file) - config = importlib.import_module(config_module).make_config() - config = override(config, ["--", f"experiment={config_job_name}"]) - if base_checkpoint_dir != "" and hasattr(config.model, "base_load_from"): - if hasattr(config.model.base_load_from, "load_path"): - if config.model.base_load_from.load_path != "": - config.model.base_load_from.load_path = config.model.base_load_from.load_path.replace( - "checkpoints", base_checkpoint_dir - ) - log.info( - f"Model need to load a base model weight, change the loading path from default folder to the {base_checkpoint_dir}" - ) - - # Check that the config is valid - config.validate() - # Freeze the config so developers don't change it during training. - config.freeze() # type: ignore - - # Initialize model - with skip_init_linear(): - model = model_class(config.model) - return model - - -def load_network_model(model: DiffusionT2WModel, ckpt_path: str): - if ckpt_path: - with skip_init_linear(): - model.set_up_model() - net_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False) # , weights_only=True) - non_strict_load_model(model.model, net_state_dict) - else: - model.set_up_model() - model.cuda() - - -def load_tokenizer_model(model: DiffusionT2WModel, tokenizer_dir: str): - with skip_init_linear(): - model.set_up_tokenizer(tokenizer_dir) - model.cuda() - - -def prepare_data_batch( - height: int, - width: int, - num_frames: int, - fps: int, - prompt_embedding: torch.Tensor, - negative_prompt_embedding: Optional[torch.Tensor] = None, -): - """Prepare input batch tensors for video generation. 
- - Args: - height (int): Height of video frames - width (int): Width of video frames - num_frames (int): Number of frames to generate - fps (int): Frames per second - prompt_embedding (torch.Tensor): Encoded text prompt embeddings - negative_prompt_embedding (torch.Tensor, optional): Encoded negative prompt embeddings - - Returns: - dict: Batch dictionary containing: - - video: Zero tensor of target video shape - - t5_text_mask: Attention mask for text embeddings - - image_size: Target frame dimensions - - fps: Target frame rate - - num_frames: Number of frames - - padding_mask: Frame padding mask - - t5_text_embeddings: Prompt embeddings - - neg_t5_text_embeddings: Negative prompt embeddings (if provided) - - neg_t5_text_mask: Mask for negative embeddings (if provided) - """ - # Create base data batch - data_batch = { - "video": torch.zeros((1, 3, num_frames, height, width), dtype=torch.uint8).cuda(), - "t5_text_mask": torch.ones(1, 512, dtype=torch.bfloat16).cuda(), - "image_size": torch.tensor([[height, width, height, width]] * 1, dtype=torch.bfloat16).cuda(), - "fps": torch.tensor([fps] * 1, dtype=torch.bfloat16).cuda(), - "num_frames": torch.tensor([num_frames] * 1, dtype=torch.bfloat16).cuda(), - "padding_mask": torch.zeros((1, 1, height, width), dtype=torch.bfloat16).cuda(), - } - - # Handle text embeddings - - t5_embed = prompt_embedding.to(dtype=torch.bfloat16).cuda() - data_batch["t5_text_embeddings"] = t5_embed - - if negative_prompt_embedding is not None: - neg_t5_embed = negative_prompt_embedding.to(dtype=torch.bfloat16).cuda() - data_batch["neg_t5_text_embeddings"] = neg_t5_embed - data_batch["neg_t5_text_mask"] = torch.ones(1, 512, dtype=torch.bfloat16).cuda() - - return data_batch - - -def get_video_batch(model, prompt_embedding, negative_prompt_embedding, height, width, fps, num_video_frames): - """Prepare complete input batch for video generation including latent dimensions. 
- - Args: - model: Diffusion model instance - prompt_embedding (torch.Tensor): Text prompt embeddings - negative_prompt_embedding (torch.Tensor): Negative prompt embeddings - height (int): Output video height - width (int): Output video width - fps (int): Output video frame rate - num_video_frames (int): Number of frames to generate - - Returns: - tuple: - - data_batch (dict): Complete model input batch - - state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression - """ - raw_video_batch = prepare_data_batch( - height=height, - width=width, - num_frames=num_video_frames, - fps=fps, - prompt_embedding=prompt_embedding, - negative_prompt_embedding=negative_prompt_embedding, - ) - state_shape = [ - model.tokenizer.channel, - model.tokenizer.get_latent_num_frames(num_video_frames), - height // model.tokenizer.spatial_compression_factor, - width // model.tokenizer.spatial_compression_factor, - ] - return raw_video_batch, state_shape - - -def resize_video(video_np, h, w, interpolation=cv2.INTER_AREA): - """Resize video frames to the specified height and width.""" - video_np = video_np[0].transpose((1, 2, 3, 0)) # Convert to T x H x W x C - t = video_np.shape[0] - resized_video = np.zeros((t, h, w, 3), dtype=np.uint8) - for i in range(t): - resized_video[i] = cv2.resize(video_np[i], (w, h), interpolation=interpolation) - return resized_video.transpose((3, 0, 1, 2))[None] # Convert back to B x C x T x H x W - - -def detect_aspect_ratio(img_size: tuple[int]): - """Function for detecting the closest aspect ratio.""" - - _aspect_ratios = np.array([(16 / 9), (4 / 3), 1, (3 / 4), (9 / 16)]) - _aspect_ratio_keys = ["16,9", "4,3", "1,1", "3,4", "9,16"] - w, h = img_size - current_ratio = w / h - closest_aspect_ratio = np.argmin((_aspect_ratios - current_ratio) ** 2) - return _aspect_ratio_keys[closest_aspect_ratio] - - -def get_upscale_size(orig_size: tuple[int], aspect_ratio: str, upscale_factor: int = 3, patch_overlap: int = 256): - patch_w, patch_h = orig_size - if aspect_ratio == "16,9" or aspect_ratio == "4,3": - ratio = int(aspect_ratio.split(",")[1]) / int(aspect_ratio.split(",")[0]) - target_w = patch_w * upscale_factor - patch_overlap - target_h = patch_h * upscale_factor - int(patch_overlap * ratio) - elif aspect_ratio == "9,16" or aspect_ratio == "3,4": - ratio = int(aspect_ratio.split(",")[0]) / int(aspect_ratio.split(",")[1]) - target_h = patch_h * upscale_factor - patch_overlap - target_w = patch_w * upscale_factor - int(patch_overlap * ratio) - else: - target_h = patch_h * upscale_factor - patch_overlap - target_w = patch_w * upscale_factor - patch_overlap - return target_w, target_h - - -def read_and_resize_input(input_control_path, num_total_frames, interpolation): - control_input, fps = read_video_or_image_into_frames_BCTHW( - input_control_path, - normalize=False, # s.t. output range is [0, 255] - max_frames=num_total_frames, - also_return_fps=True, - ) # BCTHW - aspect_ratio = detect_aspect_ratio((control_input.shape[-1], control_input.shape[-2])) - w, h = VIDEO_RES_SIZE_INFO[aspect_ratio] - control_input = resize_video(control_input, h, w, interpolation=interpolation) # BCTHW, range [0, 255] - control_input = torch.from_numpy(control_input[0]) # CTHW, range [0, 255] - return control_input, fps, aspect_ratio - - -def get_video_batch_for_multiview_model( - model, prompt_embedding, height, width, fps, num_video_frames, frame_repeat_negative_condition -): - """Prepare complete input batch for video generation including latent dimensions. 
- - Args: - model: Diffusion model instance - prompt_embedding list(torch.Tensor): Text prompt embeddings - height (int): Output video height - width (int): Output video width - fps (int): Output video frame rate - num_video_frames (int): Number of frames to generate - frame_repeat_negative_condition (int): Number of frames to generate - - Returns: - tuple: - - data_batch (dict): Complete model input batch - - state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression - """ - n_views = len(prompt_embedding) - prompt_embedding = einops.rearrange(prompt_embedding, "n t d -> (n t) d").unsqueeze(0) - raw_video_batch = prepare_data_batch( - height=height, - width=width, - num_frames=num_video_frames, - fps=fps, - prompt_embedding=prompt_embedding, - ) - if frame_repeat_negative_condition != -1: - frame_repeat = torch.zeros(n_views) - frame_repeat[-1] = frame_repeat_negative_condition - frame_repeat[-2] = frame_repeat_negative_condition - raw_video_batch["frame_repeat"] = frame_repeat.unsqueeze(0).to(dtype=torch.bfloat16).cuda() - state_shape = [ - model.tokenizer.channel, - model.tokenizer.get_latent_num_frames(int(num_video_frames / n_views)) * n_views, - height // model.tokenizer.spatial_compression_factor, - width // model.tokenizer.spatial_compression_factor, - ] - return raw_video_batch, state_shape - - -def get_ctrl_batch_mv(H, W, data_batch, num_total_frames, control_inputs, num_views, num_video_frames): - # Initialize control input dictionary - control_input_dict = {k: v for k, v in data_batch.items()} - control_weights = [] - hint_keys = [] - for hint_key, control_info in control_inputs.items(): - if hint_key not in valid_hint_keys: - continue - if "input_control" in control_info: - cond_videos = [] - for in_file in control_info["input_control"]: - log.info(f"reading control input {in_file} for hint {hint_key}") - cond_vid, fps = read_video_or_image_into_frames_BCTHW( - in_file, - normalize=False, # s.t. output range is [0, 255] - max_frames=num_total_frames, - also_return_fps=True, - ) - cond_vid = resize_video(cond_vid, H, W, interpolation=cv2.INTER_LINEAR) - cond_vid = torch.from_numpy(cond_vid[0]) - - cond_videos.append(cond_vid) - - input_frames = torch.cat(cond_videos, dim=1) - control_input_dict[f"control_input_{hint_key}"] = input_frames - hint_keys.append(hint_key) - control_weights.append(control_info["control_weight"]) - - target_w, target_h = W, H - hint_key = "control_input_" + "_".join(hint_keys) - add_control_input = get_augmentor_for_eval(input_key="video", output_key=hint_key) - - if len(control_input_dict): - control_input = add_control_input(control_input_dict)[hint_key] - if control_input.ndim == 4: - control_input = control_input[None] - control_input = control_input.bfloat16() / 255 * 2 - 1 - control_weights = load_spatial_temporal_weights( - control_weights, B=1, T=num_total_frames, H=target_h, W=target_w, patch_h=H, patch_w=W - ) - data_batch["control_weight"] = control_weights - - if len(control_inputs) > 1: # Multicontrol enabled - data_batch["hint_key"] = "control_input_multi" - data_batch["control_input_multi"] = control_input - else: # Single-control case - data_batch["hint_key"] = hint_key - data_batch[hint_key] = control_input - - data_batch["target_h"], data_batch["target_w"] = target_h // 8, target_w // 8 - data_batch["video"] = torch.zeros((1, 3, 57, H, W), dtype=torch.uint8).cuda() # ????? 
- data_batch["image_size"] = torch.tensor([[H, W, H, W]] * 1, dtype=torch.bfloat16).cuda() - data_batch["padding_mask"] = torch.zeros((1, 1, H, W), dtype=torch.bfloat16).cuda() - - # add view indices for post-train model - if num_views == 5: - mapped_indices = [0, 1, 2, 4, 5] - view_indices_conditioning = [] - for v_index in mapped_indices: - view_indices_conditioning.append(torch.ones(num_video_frames, device="cuda") * v_index) - view_indices_conditioning = torch.cat(view_indices_conditioning, dim=0) - data_batch["view_indices"] = view_indices_conditioning.unsqueeze(0).contiguous() - - return data_batch - - -def get_batched_ctrl_batch( - model, - prompt_embeddings, # [B, ...] - negative_prompt_embeddings, # [B, ...] or None - height, - width, - fps, - num_video_frames, - input_video_paths, # List[str], length B - control_inputs_list, # List[dict], length B - blur_strength, - canny_threshold, -): - """ - Create a fully batched data_batch for video generation, including all control and video inputs. - - Args: - model: The diffusion model instance. - prompt_embeddings: [B, ...] tensor of prompt embeddings. - negative_prompt_embeddings: [B, ...] tensor of negative prompt embeddings or None. - height, width, fps, num_video_frames: Video parameters. - input_video_paths: List of input video paths, length B. - control_inputs_list: List of control input dicts, length B. - blur_strength, canny_threshold: ControlNet augmentation parameters. - - Returns: - data_batch: Dict with all fields batched along dim 0 (batch dimension). - state_shape: List describing the latent state shape. - """ - B = len(input_video_paths) - - def prepare_single_data_batch(b): - data_batch = { - "video": torch.zeros((1, 3, num_video_frames, height, width), dtype=torch.uint8).cuda(), - "t5_text_mask": torch.ones(1, 512, dtype=torch.bfloat16).cuda(), - "image_size": torch.tensor([[height, width, height, width]], dtype=torch.bfloat16).cuda(), - "fps": torch.tensor([fps], dtype=torch.bfloat16).cuda(), - "num_frames": torch.tensor([num_video_frames], dtype=torch.bfloat16).cuda(), - "padding_mask": torch.zeros((1, 1, height, width), dtype=torch.bfloat16).cuda(), - "t5_text_embeddings": prompt_embeddings[b : b + 1].to(dtype=torch.bfloat16).cuda(), - } - if negative_prompt_embeddings is not None: - data_batch["neg_t5_text_embeddings"] = negative_prompt_embeddings[b : b + 1].to(dtype=torch.bfloat16).cuda() - data_batch["neg_t5_text_mask"] = torch.ones(1, 512, dtype=torch.bfloat16).cuda() - return data_batch - - # Prepare and process each sample - single_batches = [] - for b in range(B): - single_data_batch = prepare_single_data_batch(b) - processed = get_ctrl_batch( - model, - single_data_batch, - num_video_frames, - input_video_paths[b], - control_inputs_list[b], - blur_strength, - canny_threshold, - ) - single_batches.append(processed) - - # Merge all single-sample batches into a batched data_batch - batched_data_batch = {} - for k in single_batches[0]: - if isinstance(single_batches[0][k], torch.Tensor): - if k == "control_weight" and single_batches[0][k].ndim == 6: - # [num_controls, 1, 1, T, H, W] per sample - # Stack along dim=1 to get [num_controls, B, 1, T, H, W] - batched_data_batch[k] = torch.cat([d[k] for d in single_batches], dim=1) - else: - # Concatenate along batch dimension (dim=0) for other tensors - batched_data_batch[k] = torch.cat([d[k] for d in single_batches], dim=0) - else: - batched_data_batch[k] = single_batches[0][k] # assume they're the same for now - - state_shape = [ - model.tokenizer.channel, - 
model.tokenizer.get_latent_num_frames(num_video_frames), - height // model.tokenizer.spatial_compression_factor, - width // model.tokenizer.spatial_compression_factor, - ] - - return batched_data_batch, state_shape - - -def get_ctrl_batch( - model, data_batch, num_video_frames, input_video_path, control_inputs, blur_strength, canny_threshold -): - """Prepare complete input batch for video generation including latent dimensions. - - Args: - model: Diffusion model instance - - Returns: - - data_batch (dict): Complete model input batch - """ - state_shape = model.state_shape - - H, W = ( - state_shape[-2] * model.tokenizer.spatial_compression_factor, - state_shape[-1] * model.tokenizer.spatial_compression_factor, - ) - - # Initialize control input dictionary - control_input_dict = {k: v for k, v in data_batch.items()} - num_total_frames = NUM_MAX_FRAMES - if input_video_path: - input_frames, fps, aspect_ratio = read_and_resize_input( - input_video_path, num_total_frames=num_total_frames, interpolation=cv2.INTER_AREA - ) - _, num_total_frames, H, W = input_frames.shape - control_input_dict["video"] = input_frames.numpy() # CTHW - data_batch["input_video"] = input_frames.bfloat16()[None] / 255 * 2 - 1 # BCTHW - else: - data_batch["input_video"] = None - target_w, target_h = W, H - - control_weights = [] - for hint_key, control_info in control_inputs.items(): - if "input_control" in control_info: - in_file = control_info["input_control"] - interpolation = cv2.INTER_NEAREST if hint_key == "seg" else cv2.INTER_LINEAR - log.info(f"reading control input {in_file} for hint {hint_key}") - control_input_dict[f"control_input_{hint_key}"], fps, aspect_ratio = read_and_resize_input( - in_file, num_total_frames=num_total_frames, interpolation=interpolation - ) # CTHW - num_total_frames = min(num_total_frames, control_input_dict[f"control_input_{hint_key}"].shape[1]) - target_h, target_w = H, W = control_input_dict[f"control_input_{hint_key}"].shape[2:] - if hint_key == "upscale": - orig_size = (W, H) - target_w, target_h = get_upscale_size(orig_size, aspect_ratio, upscale_factor=3) - input_resized = resize_video( - input_frames[None].numpy(), - target_h, - target_w, - interpolation=cv2.INTER_LINEAR, - ) # BCTHW - control_input_dict["control_input_upscale"] = split_video_into_patches( - torch.from_numpy(input_resized), H, W - ) - data_batch["input_video"] = control_input_dict["control_input_upscale"].bfloat16() / 255 * 2 - 1 - control_weights.append(control_info["control_weight"]) - - # Trim all control videos and input video to be the same length. 
- log.info(f"Making all control and input videos to be length of {num_total_frames} frames.") - if len(control_inputs) > 1: - for hint_key in control_inputs.keys(): - cur_key = f"control_input_{hint_key}" - if cur_key in control_input_dict: - control_input_dict[cur_key] = control_input_dict[cur_key][:, :num_total_frames] - if input_video_path: - control_input_dict["video"] = control_input_dict["video"][:, :num_total_frames] - data_batch["input_video"] = data_batch["input_video"][:, :, :num_total_frames] - - hint_key = "control_input_" + "_".join(control_inputs.keys()) - add_control_input = get_augmentor_for_eval( - input_key="video", - output_key=hint_key, - preset_blur_strength=blur_strength, - preset_canny_threshold=canny_threshold, - blur_config=BilateralOnlyBlurAugmentorConfig[blur_strength], - ) - - if len(control_input_dict): - control_input = add_control_input(control_input_dict)[hint_key] - if control_input.ndim == 4: - control_input = control_input[None] - control_input = control_input.bfloat16() / 255 * 2 - 1 - control_weights = load_spatial_temporal_weights( - control_weights, B=1, T=num_video_frames, H=target_h, W=target_w, patch_h=H, patch_w=W - ) - data_batch["control_weight"] = control_weights - - if len(control_inputs) > 1: # Multicontrol enabled - data_batch["hint_key"] = "control_input_multi" - data_batch["control_input_multi"] = control_input - else: # Single-control case - data_batch["hint_key"] = hint_key - data_batch[hint_key] = control_input - - data_batch["target_h"], data_batch["target_w"] = target_h // 8, target_w // 8 - data_batch["video"] = torch.zeros((1, 3, 121, H, W), dtype=torch.uint8).cuda() - data_batch["image_size"] = torch.tensor([[H, W, H, W]] * 1, dtype=torch.bfloat16).cuda() - data_batch["padding_mask"] = torch.zeros((1, 1, H, W), dtype=torch.bfloat16).cuda() - - return data_batch - - -def generate_control_input(input_file_path, save_folder, hint_key, blur_strength, canny_threshold, num_total_frames=10): - log.info( - f"Generating control input for {hint_key} with blur strength {blur_strength} and canny threshold {canny_threshold}" - ) - video_input = read_video_or_image_into_frames_BCTHW(input_file_path, normalize=False)[0, :, :num_total_frames] - control_input = get_augmentor_for_eval( - input_key="video", - output_key=hint_key, - preset_blur_strength=blur_strength, - preset_canny_threshold=canny_threshold, - blur_config=BilateralOnlyBlurAugmentorConfig[blur_strength], - ) - control_input = control_input({"video": video_input})[hint_key] - control_input = control_input.numpy().transpose((1, 2, 3, 0)) - - output_file_path = f"{save_folder}/{hint_key}_upsampler.mp4" - log.info(f"Saving control input to {output_file_path}") - save_video(frames=control_input, fps=24, filepath=output_file_path) - return output_file_path - - -def generate_world_from_control( - model: DiffusionV2WModel, - state_shape: list[int], - is_negative_prompt: bool, - data_batch: dict, - guidance: float, - num_steps: int, - seed: int, - condition_latent: torch.Tensor, - num_input_frames: int, - sigma_max: float, - x_sigma_max=None, - augment_sigma=None, - use_batch_processing: bool = True, -) -> Tuple[np.array, list, list]: - """Generate video using a conditioning video/image input. 
- - Args: - model (DiffusionV2WModel): The diffusion model instance - state_shape (list[int]): Shape of the latent state [C,T,H,W] - is_negative_prompt (bool): Whether negative prompt is provided - data_batch (dict): Batch containing model inputs including text embeddings - guidance (float): Classifier-free guidance scale for sampling - num_steps (int): Number of diffusion sampling steps - seed (int): Random seed for generation - condition_latent (torch.Tensor): Latent tensor from conditioning video/image file - num_input_frames (int): Number of input frames - - Returns: - np.array: Generated video frames in shape [T,H,W,C], range [0,255] - """ - assert not model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i, "not supported" - - if augment_sigma is None: - augment_sigma = DEFAULT_AUGMENT_SIGMA - - b, c, t, h, w = condition_latent.shape - if condition_latent.shape[2] < state_shape[1]: - # Padding condition latent to state shape - condition_latent = torch.cat( - [ - condition_latent, - condition_latent.new_zeros(b, c, state_shape[1] - t, h, w), - ], - dim=2, - ).contiguous() - num_of_latent_condition = compute_num_latent_frames(model, num_input_frames) - - sample = model.generate_samples_from_batch( - data_batch, - guidance=guidance, - state_shape=[c, t, h, w], - num_steps=num_steps, - is_negative_prompt=is_negative_prompt, - seed=seed, - condition_latent=condition_latent, - num_condition_t=num_of_latent_condition, - condition_video_augment_sigma_in_inference=augment_sigma, - x_sigma_max=x_sigma_max, - sigma_max=sigma_max, - target_h=data_batch["target_h"], - target_w=data_batch["target_w"], - patch_h=h, - patch_w=w, - use_batch_processing=use_batch_processing, - ) - return sample - - -def read_video_or_image_into_frames_BCTHW( - input_path: str, - input_path_format: str = "mp4", - H: int = None, - W: int = None, - normalize: bool = True, - max_frames: int = -1, - also_return_fps: bool = False, -) -> torch.Tensor: - """Read video or image file and convert to tensor format. 
- - Args: - input_path (str): Path to input video/image file - input_path_format (str): Format of input file (default: "mp4") - H (int, optional): Height to resize frames to - W (int, optional): Width to resize frames to - normalize (bool): Whether to normalize pixel values to [-1,1] (default: True) - max_frames (int): Maximum number of frames to read (-1 for all frames) - also_return_fps (bool): Whether to return fps along with frames - - Returns: - torch.Tensor | tuple: Video tensor in shape [B,C,T,H,W], optionally with fps if requested - """ - log.debug(f"Reading video from {input_path}") - - loaded_data = load_from_fileobj(input_path, format=input_path_format) - frames, meta_data = loaded_data - if input_path.endswith(".png") or input_path.endswith(".jpg") or input_path.endswith(".jpeg"): - frames = np.array(frames[0]) # HWC, [0,255] - if frames.shape[-1] > 3: # RGBA, set the transparent to white - # Separate the RGB and Alpha channels - rgb_channels = frames[..., :3] - alpha_channel = frames[..., 3] / 255.0 # Normalize alpha channel to [0, 1] - - # Create a white background - white_bg = np.ones_like(rgb_channels) * 255 # White background in RGB - - # Blend the RGB channels with the white background based on the alpha channel - frames = (rgb_channels * alpha_channel[..., None] + white_bg * (1 - alpha_channel[..., None])).astype( - np.uint8 - ) - frames = [frames] - fps = 0 - else: - fps = int(meta_data.get("fps")) - if max_frames != -1: - frames = frames[:max_frames] - input_tensor = np.stack(frames, axis=0) - input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w") - if normalize: - input_tensor = input_tensor / 128.0 - 1.0 - input_tensor = torch.from_numpy(input_tensor).bfloat16() # TCHW - log.debug(f"Raw data shape: {input_tensor.shape}") - if H is not None and W is not None: - input_tensor = transforms_F.resize( - input_tensor, - size=(H, W), # type: ignore - interpolation=transforms_F.InterpolationMode.BICUBIC, - antialias=True, - ) - input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1) - if normalize: - input_tensor = input_tensor.to("cuda") - log.debug(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}") - if also_return_fps: - return input_tensor, fps - return input_tensor - - -def compute_num_latent_frames(model: DiffusionV2WModel, num_input_frames: int, downsample_factor=8) -> int: - """This function computes the number of latent frames given the number of input frames. 
- Args: - model (DiffusionV2WModel): video generation model - num_input_frames (int): number of input frames - downsample_factor (int): downsample factor for temporal reduce - Returns: - int: number of latent frames - """ - # First find how many vae chunks are contained with in num_input_frames - num_latent_frames = ( - num_input_frames - // model.tokenizer.video_vae.pixel_chunk_duration - * model.tokenizer.video_vae.latent_chunk_duration - ) - # Then handle the remainder - if num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration == 1: - num_latent_frames += 1 - elif num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration > 1: - assert ( - num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1 - ) % downsample_factor == 0, f"num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1 must be divisible by {downsample_factor}" - num_latent_frames += ( - 1 + (num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1) // downsample_factor - ) - - return num_latent_frames - - -def create_condition_latent_from_input_frames( - model: DiffusionV2WModel, - input_frames: torch.Tensor, - num_frames_condition: int = 25, - from_back: bool = True, -): - """Create condition latent for video generation from input frames. - - Takes the last num_frames_condition frames from input as conditioning. - - Args: - model (DiffusionV2WModel): Video generation model - input_frames (torch.Tensor): Input video tensor [B,C,T,H,W], range [-1,1] - num_frames_condition (int): Number of frames to use for conditioning - - Returns: - tuple: (condition_latent, encode_input_frames) where: - - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W] - - encode_input_frames (torch.Tensor): Padded input frames used for encoding - """ - B, C, T, H, W = input_frames.shape - num_frames_encode = ( - model.tokenizer.pixel_chunk_duration - ) # (model.state_shape[1] - 1) / model.vae.pixel_chunk_duration + 1 - log.debug( - f"num_frames_encode not set, set it based on pixel chunk duration and model state shape: {num_frames_encode}" - ) - - log.debug( - f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}" - ) - - assert ( - input_frames.shape[2] >= num_frames_condition - ), f"input_frames not enough for condition, require at least {num_frames_condition}, get {input_frames.shape[2]}, {input_frames.shape}" - assert ( - num_frames_encode >= num_frames_condition - ), f"num_frames_encode should be larger than num_frames_condition, get {num_frames_encode}, {num_frames_condition}" - - # Put the conditioal frames to the begining of the video, and pad the end with zero - if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": - condition_frames_first = input_frames[:, :, :num_frames_condition] - condition_frames_last = input_frames[:, :, -num_frames_condition:] - padding_frames = condition_frames_first.new_zeros(B, C, num_frames_encode + 1 - 2 * num_frames_condition, H, W) - encode_input_frames = torch.cat([condition_frames_first, padding_frames, condition_frames_last], dim=2) - elif not from_back: - condition_frames = input_frames[:, :, :num_frames_condition] - padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W) - encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2) - else: - condition_frames = input_frames[:, :, -num_frames_condition:] - padding_frames = condition_frames.new_zeros(B, C, 
num_frames_encode - num_frames_condition, H, W) - encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2) - - log.info( - f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end" - ) - if hasattr(model, "n_views") and encode_input_frames.shape[0] == model.n_views: - encode_input_frames = einops.rearrange(encode_input_frames, "(B V) C T H W -> B C (V T) H W", V=model.n_views) - latent = model.encode(encode_input_frames) - elif model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": - latent1 = model.encode(encode_input_frames[:, :, :num_frames_encode]) # BCTHW - latent2 = model.encode(encode_input_frames[:, :, num_frames_encode:]) - latent = torch.cat([latent1, latent2], dim=2) # BCTHW - elif encode_input_frames.shape[0] == 1: - # treat as single view video - latent = model.tokenizer.encode(encode_input_frames) * model.sigma_data - else: - raise ValueError( - f"First dimension of encode_input_frames {encode_input_frames.shape[0]} does not match " - f"model.n_views or model.n_views is not defined and first dimension is not 1" - ) - return latent, encode_input_frames - - -def compute_num_frames_condition(model: DiffusionV2WModel, num_of_latent_overlap: int, downsample_factor=8) -> int: - """This function computes the number of condition pixel frames given the number of latent frames to overlap. - Args: - model (ExtendDiffusionModel): video generation model - num_of_latent_overlap (int): number of latent frames to overlap - downsample_factor (int): downsample factor for temporal reduce - Returns: - int: number of condition frames in output space - """ - if getattr(model.tokenizer.video_vae, "is_casual", True): - # For casual model - num_frames_condition = ( - num_of_latent_overlap - // model.tokenizer.video_vae.latent_chunk_duration - * model.tokenizer.video_vae.pixel_chunk_duration - ) - if num_of_latent_overlap % model.tokenizer.video_vae.latent_chunk_duration == 1: - num_frames_condition += 1 - elif num_of_latent_overlap % model.tokenizer.video_vae.latent_chunk_duration > 1: - num_frames_condition += ( - 1 + (num_of_latent_overlap % model.tokenizer.video_vae.latent_chunk_duration - 1) * downsample_factor - ) - else: - num_frames_condition = num_of_latent_overlap * downsample_factor - - return num_frames_condition - - -def get_condition_latent( - model: DiffusionV2WModel, - input_image_or_video_path: str, - num_input_frames: int = 1, - state_shape: list[int] = None, - frame_index: int = 0, - frame_stride: int = 1, - from_back: bool = True, - start_frame: int = 0, -) -> torch.Tensor: - """Get condition latent from input image/video file. 
- - Args: - model (DiffusionV2WModel): Video generation model - input_image_or_video_path (str): Path to conditioning image/video - num_input_frames (int): Number of input frames for video2world prediction - - Returns: - tuple: (condition_latent, input_frames) where: - - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W] - - input_frames (torch.Tensor): Input frames tensor [B,C,T,H,W] - """ - if state_shape is None: - state_shape = model.state_shape - assert num_input_frames > 0, "num_input_frames must be greater than 0" - - H, W = ( - state_shape[-2] * model.tokenizer.spatial_compression_factor, - state_shape[-1] * model.tokenizer.spatial_compression_factor, - ) - - input_path_format = input_image_or_video_path.split(".")[-1] - input_frames = read_video_or_image_into_frames_BCTHW( - input_image_or_video_path, - input_path_format=input_path_format, - H=H, - W=W, - ) - if model.config.conditioner.video_cond_bool.condition_location == "first_and_last_1": - start_frame = frame_index * frame_stride - end_frame = (frame_index + 1) * frame_stride - curr_input_frames = torch.cat( - [input_frames[:, :, start_frame : start_frame + 1], input_frames[:, :, end_frame : end_frame + 1]], dim=2 - ).contiguous() # BCTHW - num_of_latent_condition = 1 - num_frames_condition = compute_num_frames_condition( - model, num_of_latent_condition, downsample_factor=model.tokenizer.temporal_compression_factor - ) - - condition_latent, _ = create_condition_latent_from_input_frames(model, curr_input_frames, num_frames_condition) - condition_latent = condition_latent.to(torch.bfloat16) - return condition_latent - input_frames = input_frames[:, :, start_frame:, :, :] - condition_latent, _ = create_condition_latent_from_input_frames( - model, input_frames, num_input_frames, from_back=from_back - ) - condition_latent = condition_latent.to(torch.bfloat16) - - return condition_latent - - -def check_input_frames(input_path: str, required_frames: int) -> bool: - """Check if input video/image has sufficient frames. 
- - Args: - input_path: Path to input video or image - required_frames: Number of required frames - - Returns: - np.ndarray of frames if valid, None if invalid - """ - if input_path.endswith((".jpg", ".jpeg", ".png")): - if required_frames > 1: - log.error(f"Input ({input_path}) is an image but {required_frames} frames are required") - return False - return True # Let the pipeline handle image loading - # For video input - try: - vid = imageio.get_reader(input_path, "ffmpeg") - frame_count = vid.count_frames() - - if frame_count < required_frames: - log.error(f"Input video has {frame_count} frames but {required_frames} frames are required") - return False - else: - return True - except Exception as e: - log.error(f"Error reading video file {input_path}: {e}") - return False - - -def load_spatial_temporal_weights(weight_paths, B, T, H, W, patch_h, patch_w): - """ - Load and process spatial-temporal weight maps from .pt files - Args: - weight_paths: List of weights that can be scalars, paths to .pt files, or empty strings - B, T, H, W: Desired tensor dimensions - patch_h, patch_w: Patch dimensions for splitting - Returns: - For all scalar weights: tensor of shape [num_controls] - For any spatial maps: tensor of shape [num_controls, B, 1, T, H, W] - """ - # Process each weight path - weights = [] - has_spatial_weights = False - for path in weight_paths: - if not path or (isinstance(path, str) and path.lower() == "none"): - # Use default weight of 1.0 - w = torch.ones((T, H, W), dtype=torch.bfloat16) - else: - try: - # Try to parse as scalar - scalar_value = float(path) - w = torch.full((T, H, W), scalar_value, dtype=torch.bfloat16) - except ValueError: - # Not a scalar, must be a path to a weight map - has_spatial_weights = True - w = torch.load(path, weights_only=False).to(dtype=torch.bfloat16) # [T, H, W] - if w.ndim == 2: # Spatial only - w = w.unsqueeze(0).repeat(T, 1, 1) - elif w.ndim != 3: - raise ValueError(f"Weight map must be 2D or 3D, got shape {w.shape}") - - if w.shape != (T, H, W): - w = ( - torch.nn.functional.interpolate( - w.unsqueeze(0).unsqueeze(0), - size=(T, H, W), - mode="trilinear", - align_corners=False, - ) - .squeeze(0) - .squeeze(0) - ) - w = torch.clamp(w, min=0) - w = w.unsqueeze(0).unsqueeze(1) - w = w.expand(B, 1, -1, -1, -1) - weights.append(w) - - if not has_spatial_weights: - scalar_weights = [float(w) for w in weight_paths] - weights_tensor = torch.tensor(scalar_weights, dtype=torch.bfloat16) - weights_tensor = weights_tensor / (weights_tensor.sum().clip(1)) - return weights_tensor.cuda() - - weights = torch.stack(weights, dim=0).cuda() - weights = weights / (weights.sum(dim=0, keepdim=True).clip(1)) - - # Split into patches if needed - if patch_h != H or patch_w != W: - num_controls = len(weights) - weights = weights.reshape(num_controls * B, 1, T, H, W) - weights = split_video_into_patches(weights, patch_h, patch_w) - B_new = weights.shape[0] // num_controls - weights = weights.reshape(num_controls, B_new, 1, T, H, W) - - return weights - - -def resize_control_weight_map(control_weight_map, size): - assert control_weight_map.shape[2] == 1 # [num_control, B, 1, T, H, W] - weight_map = control_weight_map.squeeze(2) # [num_control, B, T, H, W] - T, H, W = size - if weight_map.shape[2:5] != (T, H, W): - assert (weight_map.shape[2] == T) or (weight_map.shape[2] == 8 * (T - 1) + 1) - weight_map_i = [ - torch.nn.functional.interpolate( - weight_map[:, :, :1], - size=(1, H, W), - mode="trilinear", - align_corners=False, - ) - ] - weight_map_i += [ - 
torch.nn.functional.interpolate( - weight_map[:, :, 1:], - size=(T - 1, H, W), - mode="trilinear", - align_corners=False, - ) - ] - weight_map = torch.cat(weight_map_i, dim=2) - return weight_map.unsqueeze(2) - - -def split_video_into_patches(tensor, patch_h, patch_w): - h, w = tensor.shape[-2:] - n_img_w = (w - 1) // patch_w + 1 - n_img_h = (h - 1) // patch_h + 1 - overlap_size_h = overlap_size_w = 0 - if n_img_w > 1: - overlap_size_w = (n_img_w * patch_w - w) // (n_img_w - 1) # 512 for n=2, 320 for n=4 - assert n_img_w * patch_w - overlap_size_w * (n_img_w - 1) == w - if n_img_h > 1: - overlap_size_h = (n_img_h * patch_h - h) // (n_img_h - 1) - assert n_img_h * patch_h - overlap_size_h * (n_img_h - 1) == h - p_h = patch_h - overlap_size_h - p_w = patch_w - overlap_size_w - - patches = [] - for i in range(n_img_h): - for j in range(n_img_w): - patches += [tensor[:, :, :, p_h * i : (p_h * i + patch_h), p_w * j : (p_w * j + patch_w)]] - return torch.cat(patches) - - -def merge_patches_into_video(imgs, overlap_size_h, overlap_size_w, n_img_h, n_img_w): - b, c, t, h, w = imgs.shape - imgs = rearrange(imgs, "(b m n) c t h w -> m n b c t h w", m=n_img_h, n=n_img_w) - H = n_img_h * h - (n_img_h - 1) * overlap_size_h - W = n_img_w * w - (n_img_w - 1) * overlap_size_w - img_sum = torch.zeros((b // (n_img_h * n_img_w), c, t, H, W)).to(imgs) - mask_sum = torch.zeros((H, W)).to(imgs) - - # Create a linear mask for blending. - def create_linear_gradient_tensor(H, W, overlap_size_h, overlap_size_w): - y, x = torch.meshgrid( - torch.minimum(torch.arange(H), H - torch.arange(H)) / (overlap_size_h + 1e-6), - torch.minimum(torch.arange(W), W - torch.arange(W)) / (overlap_size_w + 1e-6), - ) - return torch.clamp(y, 0.01, 1) * torch.clamp(x, 0.01, 1) - - mask_ij = create_linear_gradient_tensor(h, w, overlap_size_h, overlap_size_w).to(imgs) - - for i in range(n_img_h): - for j in range(n_img_w): - h_start = i * (h - overlap_size_h) - w_start = j * (w - overlap_size_w) - img_sum[:, :, :, h_start : h_start + h, w_start : w_start + w] += ( - imgs[i, j] * mask_ij[None, None, None, :, :] - ) - mask_sum[h_start : h_start + h, w_start : w_start + w] += mask_ij - return img_sum / (mask_sum[None, None, None, :, :] + 1e-6) - - -valid_hint_keys = {"vis", "seg", "edge", "depth", "keypoint", "upscale", "hdmap", "lidar"} - - -def load_controlnet_specs(cfg) -> Dict[str, Any]: - with open(cfg.controlnet_specs, "r") as f: - controlnet_specs_in = json.load(f) - - controlnet_specs = {} - args = {} - - for hint_key, config in controlnet_specs_in.items(): - if hint_key in valid_hint_keys: - controlnet_specs[hint_key] = config - else: - if type(config) == dict: - raise ValueError(f"Invalid hint_key: {hint_key}. Must be one of {valid_hint_keys}") - else: - args[hint_key] = config - continue - return controlnet_specs, args - - -def validate_controlnet_specs(cfg, controlnet_specs) -> Dict[str, Any]: - """ - Load and validate controlnet specifications from a JSON file. - - Args: - json_path (str): Path to the JSON file containing controlnet specs. - checkpoint_dir (str): Base directory for checkpoint files. - - Returns: - Dict[str, Any]: Validated and processed controlnet specifications. - """ - checkpoint_dir = cfg.checkpoint_dir - sigma_max = cfg.sigma_max - input_video_path = cfg.input_video_path - use_distilled = cfg.use_distilled - - for hint_key, config in controlnet_specs.items(): - if hint_key not in valid_hint_keys: - raise ValueError(f"Invalid hint_key: {hint_key}. 
Must be one of {valid_hint_keys}") - - if not input_video_path and sigma_max < 80: - raise ValueError("Must have 'input_video' specified if sigma_max < 80") - - if not input_video_path and "input_control" not in config: - raise ValueError( - f"{hint_key} controlnet must have 'input_control' video specified if no 'input_video' specified." - ) - - if "ckpt_path" not in config: - log.info(f"No checkpoint path specified for {hint_key}. Using default.") - ckpt_path = os.path.join(checkpoint_dir, default_model_names[hint_key]) - if use_distilled: - if hint_key in default_distilled_model_names: - ckpt_path = os.path.join(checkpoint_dir, default_distilled_model_names[hint_key]) - else: - log.info(f"No default distilled checkpoint for {hint_key}. Using full checkpoint") - - config["ckpt_path"] = ckpt_path - log.info(f"Using default checkpoint path: {config['ckpt_path']}") - - # Regardless whether "control_weight_prompt" is provided (i.e. whether we automatically - # generate spatiotemporal control weight binary masks), control_weight is needed to. - if "control_weight" not in config: - log.warning(f"No control weight specified for {hint_key}. Setting to 0.5.") - config["control_weight"] = "0.5" - else: - # Check if control weight is a path or a scalar - weight = config["control_weight"] - if not isinstance(weight, str) or not weight.endswith(".pt"): - try: - # Try converting to float - scalar_value = float(weight) - if scalar_value < 0: - raise ValueError(f"Control weight for {hint_key} must be non-negative.") - except ValueError: - raise ValueError( - f"Control weight for {hint_key} must be a valid non-negative float or a path to a .pt file." - ) - - return controlnet_specs diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/inference/preprocessors.py b/cosmos-transfer1/cosmos_transfer1/diffusion/inference/preprocessors.py deleted file mode 100644 index 620d08b24ac8808deabd714ca34df3208d5f6129..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/inference/preprocessors.py +++ /dev/null @@ -1,183 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
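For orientation before the preprocessor module that follows, here is a minimal sketch of the kind of spec that load_controlnet_specs and validate_controlnet_specs above consume. The file names, weights, and paths are hypothetical; only the key names and the hint-key set come from the code above.

```python
# Sketch only: mimics how load_controlnet_specs splits a controlnet spec JSON into
# per-modality configs and extra top-level arguments. Paths/values are made up.
import json

VALID_HINT_KEYS = {"vis", "seg", "edge", "depth", "keypoint", "upscale", "hdmap", "lidar"}

spec = {
    "edge": {"control_weight": "0.8"},                  # scalar weight given as a string
    "seg": {
        "input_control": "assets/example_seg.mp4",      # hypothetical pre-computed control video
        "control_weight": "weights/seg_weight_map.pt",  # spatial weight map stored as a .pt file
    },
    # Non-hint scalar entries are routed into `args` and later merged into the argparse namespace.
    "sigma_max": 70.0,
    "input_video_path": "assets/example_input.mp4",
}

controlnet_specs = {k: v for k, v in spec.items() if k in VALID_HINT_KEYS}
# Note: the real loader raises ValueError for an unknown key whose value is a dict,
# rather than silently dropping it; this filter is only illustrative.
args = {k: v for k, v in spec.items() if k not in VALID_HINT_KEYS and not isinstance(v, dict)}

print(json.dumps(controlnet_specs, indent=2))
print(json.dumps(args, indent=2))
```

validate_controlnet_specs then fills in default ckpt_path and control_weight values for each hint key and rejects negative or malformed weights.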
- -import json -import os - -import torch - -from cosmos_transfer1.auxiliary.depth_anything.model.depth_anything import DepthAnythingModel -from cosmos_transfer1.auxiliary.human_keypoint.human_keypoint import HumanKeypointModel -from cosmos_transfer1.auxiliary.sam2.sam2_model import VideoSegmentationModel -from cosmos_transfer1.diffusion.inference.inference_utils import valid_hint_keys -from cosmos_transfer1.utils import log -from cosmos_transfer1.utils.video_utils import is_valid_video, video_to_tensor - - -class Preprocessors: - def __init__(self): - self.depth_model = None - self.seg_model = None - self.keypoint_model = None - - def __call__(self, input_video, input_prompt, control_inputs, output_folder, regional_prompts=None): - for hint_key in control_inputs: - if hint_key in valid_hint_keys: - if hint_key in ["depth", "seg", "keypoint"]: - self.gen_input_control(input_video, input_prompt, hint_key, control_inputs[hint_key], output_folder) - - # for all hints we need to create weight tensor if not present - control_input = control_inputs[hint_key] - # For each control input modality, compute a spatiotemporal weight tensor as long as - # the user provides "control_weight_prompt". The object specified in the - # control_weight_prompt will be treated as foreground and have control_weight for these locations. - # Everything else will be treated as background and have control weight 0 at those locations. - if control_input.get("control_weight_prompt", None) is not None: - prompt = control_input["control_weight_prompt"] - log.info(f"{hint_key}: generating control weight tensor with SAM using {prompt=}") - out_tensor = os.path.join(output_folder, f"{hint_key}_control_weight.pt") - out_video = os.path.join(output_folder, f"{hint_key}_control_weight.mp4") - weight_scaler = ( - control_input["control_weight"] if isinstance(control_input["control_weight"], float) else 1.0 - ) - self.segmentation( - in_video=input_video, - out_tensor=out_tensor, - out_video=out_video, - prompt=prompt, - weight_scaler=weight_scaler, - binarize_video=True, - ) - control_input["control_weight"] = out_tensor - if regional_prompts and len(regional_prompts): - log.info(f"processing regional prompts: {regional_prompts}") - for i, regional_prompt in enumerate(regional_prompts): - log.info(f"generating regional context for {regional_prompt}") - out_tensor = os.path.join(output_folder, f"regional_context_{i}.pt") - if "mask_prompt" in regional_prompt: - prompt = regional_prompt["mask_prompt"] - out_video = os.path.join(output_folder, f"regional_context_{i}.mp4") - self.segmentation( - in_video=input_video, - out_tensor=out_tensor, - out_video=out_video, - prompt=prompt, - weight_scaler=1.0, - legacy_mask=True, - ) - if os.path.exists(out_tensor): - regional_prompt["region_definitions_path"] = out_tensor - elif "region_definitions_path" in regional_prompt and isinstance( - regional_prompt["region_definitions_path"], str - ): - if is_valid_video(regional_prompt["region_definitions_path"]): - log.info(f"converting video to tensor: {regional_prompt['region_definitions_path']}") - video_to_tensor(regional_prompt["region_definitions_path"], out_tensor) - regional_prompt["region_definitions_path"] = out_tensor - else: - raise ValueError(f"Invalid video file: {regional_prompt['region_definitions_path']}") - else: - log.info("do nothing!") - - return control_inputs - - def gen_input_control(self, in_video, in_prompt, hint_key, control_input, output_folder): - # if input control isn't provided we need to run preprocessor to create 
input control tensor - # for depth no special params, for SAM we need to run with prompt - if control_input.get("input_control", None) is None: - out_video = os.path.join(output_folder, f"{hint_key}_input_control.mp4") - control_input["input_control"] = out_video - if hint_key == "seg": - prompt = control_input.get("input_control_prompt", in_prompt) - prompt = " ".join(prompt.split()[:128]) - log.info( - f"no input_control provided for {hint_key}. generating input control video with SAM using {prompt=}" - ) - self.segmentation( - in_video=in_video, - out_video=out_video, - prompt=prompt, - ) - elif hint_key == "depth": - log.info( - f"no input_control provided for {hint_key}. generating input control video with DepthAnythingModel" - ) - self.depth( - in_video=in_video, - out_video=out_video, - ) - else: - log.info(f"no input_control provided for {hint_key}. generating input control video with Openpose") - self.keypoint( - in_video=in_video, - out_video=out_video, - ) - - def depth(self, in_video, out_video): - if self.depth_model is None: - self.depth_model = DepthAnythingModel() - - self.depth_model(in_video, out_video) - - def keypoint(self, in_video, out_video): - if self.keypoint_model is None: - self.keypoint_model = HumanKeypointModel() - - self.keypoint_model(in_video, out_video) - - def segmentation( - self, - in_video, - prompt, - out_video=None, - out_tensor=None, - weight_scaler=None, - binarize_video=False, - legacy_mask=False, - ): - if self.seg_model is None: - self.seg_model = VideoSegmentationModel() - self.seg_model( - input_video=in_video, - output_video=out_video, - output_tensor=out_tensor, - prompt=prompt, - weight_scaler=weight_scaler, - binarize_video=binarize_video, - legacy_mask=legacy_mask, - ) - - -if __name__ == "__main__": - control_inputs = dict( - { - "depth": { - # "input_control": "depth_control_input.mp4", # if empty we need to run depth - # "control_weight" : "0.1", # if empty we need to run SAM - "control_weight_prompt": "a boy", # SAM weights prompt - }, - "seg": { - # "input_control": "seg_control_input.mp4", # if empty we need to run SAM - "input_control_prompt": "A boy", - "control_weight_prompt": "A boy", # if present we need to generate weight tensor - }, - }, - ) - - preprocessor = Preprocessors() - input_video = "cosmos_transfer1/models/sam2/assets/input_video.mp4" - - preprocessor(input_video, control_inputs) - print(json.dumps(control_inputs, indent=4)) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/inference/transfer.py b/cosmos-transfer1/cosmos_transfer1/diffusion/inference/transfer.py deleted file mode 100644 index a39d68b2a2f8067bccd8ffd837aa9280c96911dd..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/inference/transfer.py +++ /dev/null @@ -1,387 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
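To make the behaviour of the Preprocessors class above concrete, the sketch below shows, with hypothetical paths and without invoking any model, how a control_inputs dict is expected to change: a missing "input_control" is filled with a generated video in the output folder, and a "control_weight_prompt" causes "control_weight" to be replaced by the path of a generated .pt weight map.

```python
# Illustrative sketch only: expected shape of control_inputs before/after Preprocessors.__call__.
# File names follow the f"{hint_key}_input_control.mp4" / f"{hint_key}_control_weight.pt" pattern above.
import os

output_folder = "outputs/video_0"  # hypothetical

control_inputs_before = {
    # "control_weight" may already be "0.5" here if validate_controlnet_specs ran first.
    "depth": {"control_weight_prompt": "a boy"},
    "seg": {"input_control_prompt": "A boy", "control_weight": "0.6"},
}

control_inputs_after = {
    "depth": {
        "control_weight_prompt": "a boy",
        "input_control": os.path.join(output_folder, "depth_input_control.mp4"),   # DepthAnything output
        "control_weight": os.path.join(output_folder, "depth_control_weight.pt"),  # SAM-derived weight map
    },
    "seg": {
        "input_control_prompt": "A boy",
        "control_weight": "0.6",
        "input_control": os.path.join(output_folder, "seg_input_control.mp4"),     # SAM output
    },
}
print(control_inputs_after)
```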
- -import argparse -import copy -import json -import os - -os.environ["TOKENIZERS_PARALLELISM"] = "false" # Workaround to suppress MP warning - -import sys -from io import BytesIO - -import torch - -from cosmos_transfer1.checkpoints import ( - BASE_7B_CHECKPOINT_AV_SAMPLE_PATH, - BASE_7B_CHECKPOINT_PATH, - EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH, -) -from cosmos_transfer1.diffusion.inference.inference_utils import load_controlnet_specs, validate_controlnet_specs -from cosmos_transfer1.diffusion.inference.preprocessors import Preprocessors -from cosmos_transfer1.diffusion.inference.world_generation_pipeline import ( - DiffusionControl2WorldGenerationPipeline, - DistilledControl2WorldGenerationPipeline, -) -from cosmos_transfer1.utils import log, misc -from cosmos_transfer1.utils.io import read_prompts_from_file, save_video - -torch.enable_grad(False) -torch.serialization.add_safe_globals([BytesIO]) - - -def parse_arguments() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Control to world generation demo script", conflict_handler="resolve") - - # Add transfer specific arguments - parser.add_argument( - "--prompt", - type=str, - default="The video captures a stunning, photorealistic scene with remarkable attention to detail, giving it a lifelike appearance that is almost indistinguishable from reality. It appears to be from a high-budget 4K movie, showcasing ultra-high-definition quality with impeccable resolution.", - help="prompt which the sampled video condition on", - ) - parser.add_argument( - "--negative_prompt", - type=str, - default="The video captures a game playing, with bad crappy graphics and cartoonish frames. It represents a recording of old outdated games. The lighting looks very fake. The textures are very raw and basic. The geometries are very primitive. The images are very pixelated and of poor CG quality. There are many subtitles in the footage. Overall, the video is unrealistic at all.", - help="negative prompt which the sampled video condition on", - ) - parser.add_argument( - "--input_video_path", - type=str, - default="", - help="Optional input RGB video path", - ) - parser.add_argument( - "--num_input_frames", - type=int, - default=1, - help="Number of conditional frames for long video generation", - choices=[1], - ) - parser.add_argument("--sigma_max", type=float, default=70.0, help="sigma_max for partial denoising") - parser.add_argument( - "--blur_strength", - type=str, - default="medium", - choices=["very_low", "low", "medium", "high", "very_high"], - help="blur strength.", - ) - parser.add_argument( - "--canny_threshold", - type=str, - default="medium", - choices=["very_low", "low", "medium", "high", "very_high"], - help="blur strength of canny threshold applied to input. 
Lower means less blur or more detected edges, which means higher fidelity to input.", - ) - parser.add_argument( - "--controlnet_specs", - type=str, - help="Path to JSON file specifying multicontrolnet configurations", - required=True, - ) - parser.add_argument( - "--is_av_sample", action="store_true", help="Whether the model is an driving post-training model" - ) - parser.add_argument( - "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" - ) - parser.add_argument( - "--tokenizer_dir", - type=str, - default="Cosmos-Tokenize1-CV8x8x8-720p", - help="Tokenizer weights directory relative to checkpoint_dir", - ) - parser.add_argument( - "--video_save_name", - type=str, - default="output", - help="Output filename for generating a single video", - ) - parser.add_argument( - "--video_save_folder", - type=str, - default="outputs/", - help="Output folder for generating a batch of videos", - ) - parser.add_argument( - "--batch_input_path", - type=str, - help="Path to a JSONL file of input prompts for generating a batch of videos", - ) - parser.add_argument("--batch_size", type=int, default=1, help="Batch size") - parser.add_argument("--num_steps", type=int, default=35, help="Number of diffusion sampling steps") - parser.add_argument("--guidance", type=float, default=5, help="Classifier-free guidance scale value") - parser.add_argument("--fps", type=int, default=24, help="FPS of the output video") - parser.add_argument("--seed", type=int, default=1, help="Random seed") - parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs used to run inference in parallel.") - parser.add_argument( - "--offload_diffusion_transformer", - action="store_true", - help="Offload DiT after inference", - ) - parser.add_argument( - "--offload_text_encoder_model", - action="store_true", - help="Offload text encoder model after inference", - ) - parser.add_argument( - "--offload_guardrail_models", - action="store_true", - help="Offload guardrail models after inference", - ) - parser.add_argument( - "--upsample_prompt", - action="store_true", - help="Upsample prompt using Pixtral upsampler model", - ) - parser.add_argument( - "--offload_prompt_upsampler", - action="store_true", - help="Offload prompt upsampler model after inference", - ) - parser.add_argument("--use_distilled", action="store_true", help="Use distilled ControlNet model variant") - - cmd_args = parser.parse_args() - - # Load and parse JSON input - control_inputs, json_args = load_controlnet_specs(cmd_args) - - log.info(f"control_inputs: {json.dumps(control_inputs, indent=4)}") - log.info(f"args in json: {json.dumps(json_args, indent=4)}") - # if parameters not set on command line, use the ones from the controlnet_specs - # if both not set use command line defaults - for key in json_args: - if f"--{key}" not in sys.argv: - setattr(cmd_args, key, json_args[key]) - - log.info(f"final args: {json.dumps(vars(cmd_args), indent=4)}") - - return cmd_args, control_inputs - - -def demo(cfg, control_inputs): - """Run control-to-world generation demo. 
- - This function handles the main control-to-world generation pipeline, including: - - Setting up the random seed for reproducibility - - Initializing the generation pipeline with the provided configuration - - Processing single or multiple prompts/images/videos from input - - Generating videos from prompts and images/videos - - Saving the generated videos and corresponding prompts to disk - - Args: - cfg (argparse.Namespace): Configuration namespace containing: - - Model configuration (checkpoint paths, model settings) - - Generation parameters (guidance, steps, dimensions) - - Input/output settings (prompts/images/videos, save paths) - - Performance options (model offloading settings) - - The function will save: - - Generated MP4 video files - - Text files containing the processed prompts - - If guardrails block the generation, a critical log message is displayed - and the function continues to the next prompt if available. - """ - - control_inputs = validate_controlnet_specs(cfg, control_inputs) - misc.set_random_seed(cfg.seed) - - device_rank = 0 - process_group = None - if cfg.num_gpus > 1: - from megatron.core import parallel_state - - from cosmos_transfer1.utils import distributed - - distributed.init() - parallel_state.initialize_model_parallel(context_parallel_size=cfg.num_gpus) - process_group = parallel_state.get_context_parallel_group() - - device_rank = distributed.get_rank(process_group) - - preprocessors = Preprocessors() - - if cfg.use_distilled: - assert not cfg.is_av_sample - checkpoint = EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH - pipeline = DistilledControl2WorldGenerationPipeline( - checkpoint_dir=cfg.checkpoint_dir, - checkpoint_name=checkpoint, - offload_network=cfg.offload_diffusion_transformer, - offload_text_encoder_model=cfg.offload_text_encoder_model, - offload_guardrail_models=cfg.offload_guardrail_models, - guidance=cfg.guidance, - num_steps=cfg.num_steps, - fps=cfg.fps, - seed=cfg.seed, - num_input_frames=cfg.num_input_frames, - control_inputs=control_inputs, - sigma_max=cfg.sigma_max, - blur_strength=cfg.blur_strength, - canny_threshold=cfg.canny_threshold, - upsample_prompt=cfg.upsample_prompt, - offload_prompt_upsampler=cfg.offload_prompt_upsampler, - process_group=process_group, - ) - else: - checkpoint = BASE_7B_CHECKPOINT_AV_SAMPLE_PATH if cfg.is_av_sample else BASE_7B_CHECKPOINT_PATH - - # Initialize transfer generation model pipeline - pipeline = DiffusionControl2WorldGenerationPipeline( - checkpoint_dir=cfg.checkpoint_dir, - checkpoint_name=checkpoint, - offload_network=cfg.offload_diffusion_transformer, - offload_text_encoder_model=cfg.offload_text_encoder_model, - offload_guardrail_models=cfg.offload_guardrail_models, - guidance=cfg.guidance, - num_steps=cfg.num_steps, - fps=cfg.fps, - seed=cfg.seed, - num_input_frames=cfg.num_input_frames, - control_inputs=control_inputs, - sigma_max=cfg.sigma_max, - blur_strength=cfg.blur_strength, - canny_threshold=cfg.canny_threshold, - upsample_prompt=cfg.upsample_prompt, - offload_prompt_upsampler=cfg.offload_prompt_upsampler, - process_group=process_group, - ) - - if cfg.batch_input_path: - log.info(f"Reading batch inputs from path: {cfg.batch_input_path}") - prompts = read_prompts_from_file(cfg.batch_input_path) - else: - # Single prompt case - prompts = [{"prompt": cfg.prompt, "visual_input": cfg.input_video_path}] - - batch_size = cfg.batch_size if hasattr(cfg, "batch_size") else 1 - if any("upscale" in control_input for control_input in control_inputs) and batch_size > 1: - batch_size = 1 - 
log.info("Setting batch_size=1 as upscale does not support batch generation") - os.makedirs(cfg.video_save_folder, exist_ok=True) - for batch_start in range(0, len(prompts), batch_size): - # Get current batch - batch_prompts = prompts[batch_start : batch_start + batch_size] - actual_batch_size = len(batch_prompts) - # Extract batch data - batch_prompt_texts = [p.get("prompt", None) for p in batch_prompts] - batch_video_paths = [p.get("visual_input", None) for p in batch_prompts] - - batch_control_inputs = [] - for i, input_dict in enumerate(batch_prompts): - current_prompt = input_dict.get("prompt", None) - current_video_path = input_dict.get("visual_input", None) - - if cfg.batch_input_path: - video_save_subfolder = os.path.join(cfg.video_save_folder, f"video_{batch_start+i}") - os.makedirs(video_save_subfolder, exist_ok=True) - else: - video_save_subfolder = cfg.video_save_folder - - current_control_inputs = copy.deepcopy(control_inputs) - if "control_overrides" in input_dict: - for hint_key, override in input_dict["control_overrides"].items(): - if hint_key in current_control_inputs: - current_control_inputs[hint_key].update(override) - else: - log.warning(f"Ignoring unknown control key in override: {hint_key}") - - # if control inputs are not provided, run respective preprocessor (for seg and depth) - log.info("running preprocessor") - preprocessors( - current_video_path, - current_prompt, - current_control_inputs, - video_save_subfolder, - cfg.regional_prompts if hasattr(cfg, "regional_prompts") else None, - ) - batch_control_inputs.append(current_control_inputs) - - regional_prompts = [] - region_definitions = [] - if hasattr(cfg, "regional_prompts") and cfg.regional_prompts: - log.info(f"regional_prompts: {cfg.regional_prompts}") - for regional_prompt in cfg.regional_prompts: - regional_prompts.append(regional_prompt["prompt"]) - if "region_definitions_path" in regional_prompt: - log.info(f"region_definitions_path: {regional_prompt['region_definitions_path']}") - region_definition_path = regional_prompt["region_definitions_path"] - if isinstance(region_definition_path, str) and region_definition_path.endswith(".json"): - with open(region_definition_path, "r") as f: - region_definitions_json = json.load(f) - region_definitions.extend(region_definitions_json) - else: - region_definitions.append(region_definition_path) - - if hasattr(pipeline, "regional_prompts"): - pipeline.regional_prompts = regional_prompts - if hasattr(pipeline, "region_definitions"): - pipeline.region_definitions = region_definitions - - # Generate videos in batch - batch_outputs = pipeline.generate( - prompt=batch_prompt_texts, - video_path=batch_video_paths, - negative_prompt=cfg.negative_prompt, - control_inputs=batch_control_inputs, - save_folder=video_save_subfolder, - batch_size=actual_batch_size, - ) - if batch_outputs is None: - log.critical("Guardrail blocked generation for entire batch.") - continue - - videos, final_prompts = batch_outputs - for i, (video, prompt) in enumerate(zip(videos, final_prompts)): - if cfg.batch_input_path: - video_save_subfolder = os.path.join(cfg.video_save_folder, f"video_{batch_start+i}") - video_save_path = os.path.join(video_save_subfolder, "output.mp4") - prompt_save_path = os.path.join(video_save_subfolder, "prompt.txt") - else: - video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4") - prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt") - # Save video and prompt - if device_rank == 0: - 
os.makedirs(os.path.dirname(video_save_path), exist_ok=True) - save_video( - video=video, - fps=cfg.fps, - H=video.shape[1], - W=video.shape[2], - video_save_quality=5, - video_save_path=video_save_path, - ) - - # Save prompt to text file alongside video - with open(prompt_save_path, "wb") as f: - f.write(prompt.encode("utf-8")) - - log.info(f"Saved video to {video_save_path}") - log.info(f"Saved prompt to {prompt_save_path}") - - # clean up properly - if cfg.num_gpus > 1: - parallel_state.destroy_model_parallel() - import torch.distributed as dist - - dist.destroy_process_group() - - -if __name__ == "__main__": - args, control_inputs = parse_arguments() - demo(args, control_inputs) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/inference/transfer_multiview.py b/cosmos-transfer1/cosmos_transfer1/diffusion/inference/transfer_multiview.py deleted file mode 100644 index 8c06948aeffce024d8304efa8ba398aa35d8b55f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/inference/transfer_multiview.py +++ /dev/null @@ -1,456 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
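Before the multi-view variant that follows, a small sketch of how the single-view transfer.py demo above lays out its outputs may help: with --batch_input_path each prompt gets its own video_{i}/ subfolder containing output.mp4 and prompt.txt, while a single-prompt run writes {video_save_name}.mp4 and .txt directly into --video_save_folder. The folder names below are hypothetical.

```python
# Sketch of the save-path convention used in demo() above (hypothetical folder names).
import os


def output_paths(video_save_folder, video_save_name, batch_input_path=None, index=0):
    """Mirror where demo() writes the generated video and the prompt text file."""
    if batch_input_path:
        subfolder = os.path.join(video_save_folder, f"video_{index}")
        return os.path.join(subfolder, "output.mp4"), os.path.join(subfolder, "prompt.txt")
    return (
        os.path.join(video_save_folder, f"{video_save_name}.mp4"),
        os.path.join(video_save_folder, f"{video_save_name}.txt"),
    )


# Single-prompt run:      ('outputs/output.mp4', 'outputs/output.txt')
print(output_paths("outputs", "output"))
# Batch run (JSONL, i=2): ('outputs/video_2/output.mp4', 'outputs/video_2/prompt.txt')
print(output_paths("outputs", "output", batch_input_path="batch.jsonl", index=2))
```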
- -import argparse -import copy -import json -import os - -os.environ["TOKENIZERS_PARALLELISM"] = "false" # Workaround to suppress MP warning - -import sys -from io import BytesIO - -import torch - -from cosmos_transfer1.checkpoints import ( - BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, - BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, -) -from cosmos_transfer1.diffusion.inference.inference_utils import ( - default_model_names, - load_controlnet_specs, - valid_hint_keys, -) -from cosmos_transfer1.diffusion.inference.preprocessors import Preprocessors -from cosmos_transfer1.diffusion.inference.world_generation_pipeline import ( - DiffusionControl2WorldMultiviewGenerationPipeline, -) -from cosmos_transfer1.utils import log, misc -from cosmos_transfer1.utils.io import save_video - -torch.enable_grad(False) - -from cosmos_transfer1.checkpoints import ( - BASE_7B_CHECKPOINT_AV_SAMPLE_PATH, - BASE_7B_CHECKPOINT_PATH, - DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH, - VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, - BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, - SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH, - SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH, - SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, -) -from cosmos_transfer1.diffusion.model.model_ctrl import VideoDiffusionModelWithCtrl, VideoDiffusionT2VModelWithCtrl -from cosmos_transfer1.diffusion.model.model_multi_camera_ctrl import MultiVideoDiffusionModelWithCtrl - -MODEL_CLASS_DICT = { - BASE_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, - EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, - VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, - DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, - SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, - KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, - UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, - BASE_7B_CHECKPOINT_AV_SAMPLE_PATH: VideoDiffusionT2VModelWithCtrl, - HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionT2VModelWithCtrl, - LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionT2VModelWithCtrl, - BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: MultiVideoDiffusionModelWithCtrl, - SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, - SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, - BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: MultiVideoDiffusionModelWithCtrl, - SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, - SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, - SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, - SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, -} - -MODEL_NAME_DICT = { - BASE_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3", - EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: 
"CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3", - VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_vis_block3", - DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_depth_block3", - KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_keypoint_block3", - SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3", - UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_upscale_block3", - BASE_7B_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_hdmap_block3", - HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_hdmap_block3", - LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_lidar_block3", - BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_block3", - BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_sv2mv_v2w_57frames_control_input_hdmap_block3", - SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_block3", - SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_lidar_block3", - SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_waymo_block3", - SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_v2w_57frames_control_input_hdmap_waymo_block3", -} - - -def parse_arguments() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Control to world generation demo script", conflict_handler="resolve") - parser.add_argument( - "--prompt", - type=str, - default="The video captures a stunning, photorealistic scene with remarkable attention to detail, giving it a lifelike appearance that is almost indistinguishable from reality. It appears to be from a high-budget 4K movie, showcasing ultra-high-definition quality with impeccable resolution.", - help="prompt which the sampled video condition on", - ) - parser.add_argument( - "--prompt_left", - type=str, - default="The video is captured from a camera mounted on a car. The camera is facing to the left. ", - help="Text prompt for generating left camera view video", - ) - parser.add_argument( - "--prompt_right", - type=str, - default="The video is captured from a camera mounted on a car. The camera is facing to the right.", - help="Text prompt for generating right camera view video", - ) - - parser.add_argument( - "--prompt_back", - type=str, - default="The video is captured from a camera mounted on a car. The camera is facing backwards.", - help="Text prompt for generating rear camera view video", - ) - parser.add_argument( - "--prompt_back_left", - type=str, - default="The video is captured from a camera mounted on a car. The camera is facing the rear left side.", - help="Text prompt for generating left camera view video", - ) - parser.add_argument( - "--prompt_back_right", - type=str, - default="The video is captured from a camera mounted on a car. The camera is facing the rear right side.", - help="Text prompt for generating right camera view video", - ) - parser.add_argument( - "--view_condition_video", - type=str, - default="", - help="We require that only a single condition view is specified and this video is treated as conditioning for that view. 
" - "This video/videos should have the same duration as control videos", - ) - parser.add_argument( - "--initial_condition_video", - type=str, - default="", - help="Can be either a path to a mp4 or a directory. If it is a mp4, we assume" - "that it is a video temporally concatenated with the same number of views as the model. " - "If it is a directory, we assume that the file names evaluate to integers that correspond to a view index," - " e.g. '000.mp4', '003.mp4', '004.mp4'." - "This video/videos should have at least num_input_frames number of frames for each view. Frames will be taken from the back" - "of the video(s) if the duration of the video in each view exceed num_input_frames", - ) - parser.add_argument( - "--num_input_frames", - type=int, - default=1, - help="Number of conditional frames for long video generation, not used in t2w", - choices=[1, 9], - ) - parser.add_argument( - "--controlnet_specs", - type=str, - help="Path to JSON file specifying multicontrolnet configurations", - required=True, - ) - parser.add_argument( - "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints" - ) - parser.add_argument( - "--tokenizer_dir", - type=str, - default="Cosmos-Tokenize1-CV8x8x8-720p", - help="Tokenizer weights directory relative to checkpoint_dir", - ) - parser.add_argument( - "--video_save_name", - type=str, - default="output", - help="Output filename for generating a single video", - ) - parser.add_argument( - "--video_save_folder", - type=str, - default="outputs/", - help="Output folder for generating a batch of videos", - ) - parser.add_argument("--num_steps", type=int, default=35, help="Number of diffusion sampling steps") - parser.add_argument("--guidance", type=float, default=5, help="Classifier-free guidance scale value") - parser.add_argument("--fps", type=int, default=24, help="FPS of the output video") - parser.add_argument("--seed", type=int, default=1, help="Random seed") - parser.add_argument("--n_clip_max", type=int, default=-1, help="Maximum number of video extension loop") - parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs used to run inference in parallel.") - parser.add_argument( - "--offload_diffusion_transformer", - action="store_true", - help="Offload DiT after inference", - ) - parser.add_argument( - "--offload_text_encoder_model", - action="store_true", - help="Offload text encoder model after inference", - ) - parser.add_argument( - "--offload_guardrail_models", - action="store_true", - help="Offload guardrail models after inference", - ) - parser.add_argument( - "--upsample_prompt", - action="store_true", - help="Upsample prompt using Pixtral upsampler model", - ) - parser.add_argument( - "--offload_prompt_upsampler", - action="store_true", - help="Offload prompt upsampler model after inference", - ) - parser.add_argument( - "--waymo_example", - type=bool, - default=False, - help="Set to true when using post-trained checkpoint from the Waymo post-training example", - ) - - cmd_args = parser.parse_args() - - # Load and parse JSON input - control_inputs, json_args = load_controlnet_specs(cmd_args) - control_inputs.update(json_args) - log.info(f"control_inputs: {json.dumps(control_inputs, indent=4)}") - log.info(f"args in json: {json.dumps(json_args, indent=4)}") - - # if parameters not set on command line, use the ones from the controlnet_specs - # if both not set use command line defaults - for key in json_args: - if f"--{key}" not in sys.argv: - setattr(cmd_args, key, json_args[key]) 
-
-    log.info(f"final args: {json.dumps(vars(cmd_args), indent=4)}")
-
-    return cmd_args, control_inputs
-
-
-def validate_controlnet_specs(cfg, controlnet_specs):
-    """
-    Validate and normalize the controlnet specifications parsed from the JSON file.
-
-    Args:
-        cfg (argparse.Namespace): Parsed command-line arguments; cfg.checkpoint_dir is used as the base directory for checkpoint files.
-        controlnet_specs (Dict[str, Any]): Controlnet specifications loaded from the JSON file.
-
-    Returns:
-        Dict[str, Any]: Validated and processed controlnet specifications.
-    """
-    checkpoint_dir = cfg.checkpoint_dir
-
-    for hint_key, config in controlnet_specs.items():
-        if hint_key not in list(valid_hint_keys) + ["prompts", "view_condition_video"]:
-            raise ValueError(f"Invalid hint_key: {hint_key}. Must be one of {valid_hint_keys}")
-        if hint_key in valid_hint_keys:
-            if "ckpt_path" not in config:
-                log.info(f"No checkpoint path specified for {hint_key}. Using default.")
-                config["ckpt_path"] = os.path.join(checkpoint_dir, default_model_names[hint_key])
-
-            # Regardless of whether "control_weight_prompt" is provided (i.e. whether we automatically
-            # generate spatiotemporal control weight binary masks), a control_weight entry is still required.
-            if "control_weight" not in config:
-                log.warning(f"No control weight specified for {hint_key}. Setting to 0.5.")
-                config["control_weight"] = "0.5"
-            else:
-                # Check if control weight is a path or a scalar
-                weight = config["control_weight"]
-                if not isinstance(weight, str) or not weight.endswith(".pt"):
-                    try:
-                        # Try converting to float
-                        scalar_value = float(weight)
-                        if scalar_value < 0:
-                            raise ValueError(f"Control weight for {hint_key} must be non-negative.")
-                    except ValueError:
-                        raise ValueError(
-                            f"Control weight for {hint_key} must be a valid non-negative float or a path to a .pt file."
-                        )
-
-    return controlnet_specs
-
-
-def demo(cfg, control_inputs):
-    """Run control-to-world generation demo.
-
-    This function handles the main control-to-world generation pipeline, including:
-    - Setting up the random seed for reproducibility
-    - Initializing the generation pipeline with the provided configuration
-    - Processing single or multiple prompts/images/videos from input
-    - Generating videos from prompts and images/videos
-    - Saving the generated videos and corresponding prompts to disk
-
-    Args:
-        cfg (argparse.Namespace): Configuration namespace containing:
-            - Model configuration (checkpoint paths, model settings)
-            - Generation parameters (guidance, steps, dimensions)
-            - Input/output settings (prompts/images/videos, save paths)
-            - Performance options (model offloading settings)
-
-    The function will save:
-        - Generated MP4 video files
-        - Text files containing the processed prompts
-
-    If guardrails block the generation, a critical log message is displayed
-    and the function continues to the next prompt if available.
- """ - - control_inputs = validate_controlnet_specs(cfg, control_inputs) - misc.set_random_seed(cfg.seed) - - device_rank = 0 - process_group = None - if cfg.num_gpus > 1: - from megatron.core import parallel_state - - from cosmos_transfer1.utils import distributed - - distributed.init() - parallel_state.initialize_model_parallel(context_parallel_size=cfg.num_gpus) - process_group = parallel_state.get_context_parallel_group() - - device_rank = distributed.get_rank(process_group) - - preprocessors = Preprocessors() - - if cfg.waymo_example: - prompts = [ - cfg.prompt, - cfg.prompt_left, - cfg.prompt_right, - cfg.prompt_back_left, - cfg.prompt_back_right, - ] - if cfg.initial_condition_video: - cfg.is_lvg_model = True - checkpoint = SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH - else: - cfg.is_lvg_model = False - cfg.num_input_frames = 0 - checkpoint = SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH - - else: - prompts = [ - cfg.prompt, - cfg.prompt_left, - cfg.prompt_right, - cfg.prompt_back, - cfg.prompt_back_left, - cfg.prompt_back_right, - ] - - if cfg.initial_condition_video: - cfg.is_lvg_model = True - checkpoint = BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH - else: - cfg.is_lvg_model = False - cfg.num_input_frames = 0 - checkpoint = BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH - - # Initialize transfer generation model pipeline - pipeline = DiffusionControl2WorldMultiviewGenerationPipeline( - checkpoint_dir=cfg.checkpoint_dir, - checkpoint_name=checkpoint, - offload_network=cfg.offload_diffusion_transformer, - offload_text_encoder_model=cfg.offload_text_encoder_model, - offload_guardrail_models=cfg.offload_guardrail_models, - guidance=cfg.guidance, - num_steps=cfg.num_steps, - fps=cfg.fps, - seed=cfg.seed, - num_input_frames=cfg.num_input_frames, - control_inputs=control_inputs, - sigma_max=80.0, - num_video_frames=57, - process_group=process_group, - height=576, - width=1024, - is_lvg_model=cfg.is_lvg_model, - n_clip_max=cfg.n_clip_max, - waymo_example=cfg.waymo_example, - ) - - os.makedirs(cfg.video_save_folder, exist_ok=True) - - current_prompt = prompts - current_video_path = "" - video_save_subfolder = os.path.join(cfg.video_save_folder, "video_0") - os.makedirs(video_save_subfolder, exist_ok=True) - current_control_inputs = copy.deepcopy(control_inputs) - - # if control inputs are not provided, run respective preprocessor (for seg and depth) - preprocessors(current_video_path, current_prompt, current_control_inputs, video_save_subfolder) - - # Generate video - generated_output = pipeline.generate( - prompts=current_prompt, - view_condition_video=cfg.view_condition_video, - initial_condition_video=cfg.initial_condition_video, - control_inputs=current_control_inputs, - save_folder=video_save_subfolder, - ) - if generated_output is None: - log.critical("Guardrail blocked generation.") - video, prompt = generated_output - - video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4") - prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt") - - if device_rank == 0: - # Save video - os.makedirs(os.path.dirname(video_save_path), exist_ok=True) - save_video( - video=video, - fps=cfg.fps, - H=video.shape[1], - W=video.shape[2], - video_save_quality=7, - video_save_path=video_save_path, - ) - - # Save prompt to text file alongside video - with open(prompt_save_path, "wb") as f: - f.write(";".join(prompt).encode("utf-8")) - - log.info(f"Saved video to {video_save_path}") - log.info(f"Saved prompt to 
{prompt_save_path}") - - # clean up properly - if cfg.num_gpus > 1: - parallel_state.destroy_model_parallel() - import torch.distributed as dist - - dist.destroy_process_group() - - -if __name__ == "__main__": - args, control_inputs = parse_arguments() - demo(args, control_inputs) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/inference/world_generation_pipeline.py b/cosmos-transfer1/cosmos_transfer1/diffusion/inference/world_generation_pipeline.py deleted file mode 100644 index 0c816a9629fabc8c51962baed9d42db03d99961f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/inference/world_generation_pipeline.py +++ /dev/null @@ -1,1427 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from collections import defaultdict -from typing import List, Optional, Union - -import cv2 -import einops -import numpy as np -import torch -from tqdm import tqdm - -from cosmos_transfer1.auxiliary.upsampler.model.upsampler import PixtralPromptUpsampler -from cosmos_transfer1.checkpoints import ( - BASE_7B_CHECKPOINT_AV_SAMPLE_PATH, - BASE_7B_CHECKPOINT_PATH, - COSMOS_TOKENIZER_CHECKPOINT, - DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH, - HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH, - VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, - BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH, - SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH, - SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, - SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH, - SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH, -) -from cosmos_transfer1.diffusion.inference.inference_utils import ( - detect_aspect_ratio, - generate_control_input, - generate_world_from_control, - get_batched_ctrl_batch, - get_ctrl_batch, - get_ctrl_batch_mv, - get_upscale_size, - get_video_batch, - get_video_batch_for_multiview_model, - load_model_by_config, - load_network_model, - load_tokenizer_model, - merge_patches_into_video, - non_strict_load_model, - read_and_resize_input, - read_video_or_image_into_frames_BCTHW, - resize_control_weight_map, - resize_video, - split_video_into_patches, - valid_hint_keys, -) -from cosmos_transfer1.diffusion.model.model_ctrl import ( - VideoDiffusionModelWithCtrl, - VideoDiffusionT2VModelWithCtrl, - VideoDistillModelWithCtrl, -) -from cosmos_transfer1.diffusion.model.model_multi_camera_ctrl import MultiVideoDiffusionModelWithCtrl -from cosmos_transfer1.diffusion.module.parallel import broadcast -from 
cosmos_transfer1.utils import log -from cosmos_transfer1.utils.base_world_generation_pipeline import BaseWorldGenerationPipeline -from cosmos_transfer1.utils.regional_prompting_utils import prepare_regional_prompts - -MODEL_NAME_DICT = { - BASE_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3", - EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3", - VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_vis_block3", - DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_depth_block3", - KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_keypoint_block3", - SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3", - UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_upscale_block3", - BASE_7B_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_hdmap_block3", - HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_hdmap_block3", - LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_t2v_121frames_control_input_lidar_block3", - BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_block3", - BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "CTRL_7Bv1pt3_sv2mv_v2w_57frames_control_input_hdmap_block3", - SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_block3", - SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_lidar_block3", - SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_t2w_57frames_control_input_hdmap_waymo_block3", - SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: "CTRL_7Bv1pt3_sv2mv_v2w_57frames_control_input_hdmap_waymo_block3", - EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH: "dev_v2w_ctrl_7bv1pt3_VisControlCanny_video_only_dmd2_fsdp", -} -MODEL_CLASS_DICT = { - BASE_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, - EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, - VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, - DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, - SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, - KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, - UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionModelWithCtrl, - BASE_7B_CHECKPOINT_AV_SAMPLE_PATH: VideoDiffusionT2VModelWithCtrl, - HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionT2VModelWithCtrl, - LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: VideoDiffusionT2VModelWithCtrl, - BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: MultiVideoDiffusionModelWithCtrl, - SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, - SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, - BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: MultiVideoDiffusionModelWithCtrl, - SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, - SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, - SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, - SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_WAYMO_CHECKPOINT_PATH: MultiVideoDiffusionModelWithCtrl, - 
EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH: VideoDistillModelWithCtrl, -} - -from collections import defaultdict - - -class DiffusionControl2WorldGenerationPipeline(BaseWorldGenerationPipeline): - def __init__( - self, - checkpoint_dir: str, - checkpoint_name: str, - has_text_input: bool = True, - offload_network: bool = False, - offload_tokenizer: bool = False, - offload_text_encoder_model: bool = False, - offload_guardrail_models: bool = False, - guidance: float = 7.0, - num_steps: int = 35, - height: int = 704, - width: int = 1280, - fps: int = 24, - num_video_frames: int = 121, - seed: int = 0, - num_input_frames: int = 1, - control_inputs: dict = None, - sigma_max: float = 70.0, - blur_strength: str = "medium", - canny_threshold: str = "medium", - upsample_prompt: bool = False, - offload_prompt_upsampler: bool = False, - process_group: torch.distributed.ProcessGroup | None = None, - regional_prompts: List[str] = None, - region_definitions: Union[List[List[float]], torch.Tensor] = None, - waymo_example: bool = False, - ): - """Initialize diffusion world generation pipeline. - - Args: - checkpoint_dir: Base directory containing model checkpoints - checkpoint_name: Name of the diffusion transformer checkpoint to use - has_text_input: Whether the pipeline takes text input for world generation - offload_network: Whether to offload diffusion transformer after inference - offload_tokenizer: Whether to offload tokenizer after inference - offload_text_encoder_model: Whether to offload T5 model after inference - offload_guardrail_models: Whether to offload guardrail models - guidance: Classifier-free guidance scale - num_steps: Number of diffusion sampling steps - height: Height of output video - width: Width of output video - fps: Frames per second of output video - num_video_frames: Number of frames to generate - seed: Random seed for sampling - num_input_frames: Number of latent conditions - control_inputs: Dictionary of control inputs for guided generation - sigma_max: Sigma max for partial denoising - blur_strength: Strength of blur applied to input - canny_threshold: Threshold for edge detection - upsample_prompt: Whether to upsample prompts using prompt upsampler model - offload_prompt_upsampler: Whether to offload prompt upsampler after use - process_group: Process group for distributed training - waymo_example: Whether to use the waymo example post-training checkpoint - """ - self.num_input_frames = num_input_frames - self.control_inputs = control_inputs - self.sigma_max = sigma_max - self.blur_strength = blur_strength - self.canny_threshold = canny_threshold - self.upsample_prompt = upsample_prompt - self.offload_prompt_upsampler = offload_prompt_upsampler - self.prompt_upsampler = None - self.upsampler_hint_key = None - self.hint_details = None - self.process_group = process_group - self.model_name = MODEL_NAME_DICT[checkpoint_name] - self.model_class = MODEL_CLASS_DICT[checkpoint_name] - self.guidance = guidance - self.num_steps = num_steps - self.height = height - self.width = width - self.fps = fps - self.num_video_frames = num_video_frames - self.seed = seed - self.regional_prompts = regional_prompts - self.region_definitions = region_definitions - - super().__init__( - checkpoint_dir=checkpoint_dir, - checkpoint_name=checkpoint_name, - has_text_input=has_text_input, - offload_network=offload_network, - offload_tokenizer=offload_tokenizer, - offload_text_encoder_model=offload_text_encoder_model, - offload_guardrail_models=offload_guardrail_models, - ) - - # Initialize prompt 
upsampler if needed - if self.upsample_prompt: - if int(os.environ["RANK"]) == 0: - self._push_torchrun_environ_variables() - self._init_prompt_upsampler() - self._pop_torchrun_environ_variables() - - def _push_torchrun_environ_variables(self): - dist_keys = [ - "RANK", - "LOCAL_RANK", - "WORLD_SIZE", - "LOCAL_WORLD_SIZE", - "GROUP_RANK", - "ROLE_RANK", - "ROLE_NAME", - "OMP_NUM_THREADS", - "MASTER_ADDR", - "MASTER_PORT", - "TORCHELASTIC_USE_AGENT_STORE", - "TORCHELASTIC_MAX_RESTARTS", - "TORCHELASTIC_RUN_ID", - "TORCH_NCCL_ASYNC_ERROR_HANDLING", - "TORCHELASTIC_ERROR_FILE", - ] - - self.torchrun_environ_variables = {} - for dist_key in dist_keys: - if dist_key in os.environ: - self.torchrun_environ_variables[dist_key] = os.environ[dist_key] - del os.environ[dist_key] - - def _pop_torchrun_environ_variables(self): - for dist_key in self.torchrun_environ_variables.keys(): - os.environ[dist_key] = self.torchrun_environ_variables[dist_key] - - def _init_prompt_upsampler(self): - """ - Initializes the prompt upsampler based on the provided control inputs. - - Returns: - None: Sets instance variables for prompt upsampler, hint key, and hint details - """ - vis_hint_keys = ["vis", "edge"] - other_hint_keys = ["seg", "depth"] - self.hint_details = None - - log.info("Initializing prompt upsampler...") - - if any(key in vis_hint_keys for key in self.control_inputs): - self.upsampler_hint_key = "vis" - self.hint_details = "vis" if "vis" in self.control_inputs else "edge" - elif any(key in other_hint_keys for key in self.control_inputs): - selected_hint_keys = [key for key in self.control_inputs if key in other_hint_keys] - self.upsampler_hint_key = selected_hint_keys[0] - else: - self.upsampler_hint_key = None - - if self.upsampler_hint_key: - self.prompt_upsampler = PixtralPromptUpsampler( - checkpoint_dir=self.checkpoint_dir, - offload_prompt_upsampler=self.offload_prompt_upsampler, - ) - - log.info( - f"Prompt upsampler initialized with hint key: {self.upsampler_hint_key} and hint details: {self.hint_details}" - ) - - def _process_prompt_upsampler(self, prompt, video_path, save_folder): - """ - Processes and upscales a given prompt using the prompt upsampler. 
- - Args: - prompt: The text prompt to upsample - video_path: Path to the input video - save_folder: Folder to save intermediate files - - Returns: - str: The upsampled prompt - """ - if not self.prompt_upsampler: - return prompt - - log.info(f"Upsampling prompt with controlnet: {self.upsampler_hint_key}") - - if self.upsampler_hint_key in ["vis"]: # input video or control input, one of them is required - # prompt upsampler for viscontrol(vis, edge) - if self.control_inputs[self.hint_details].get("input_control", None) is not None: - input_control_path = self.control_inputs[self.hint_details].get("input_control", None) - else: - hint_key = f"control_input_{self.hint_details}" - input_control_path = generate_control_input( - input_file_path=video_path, - save_folder=save_folder, - hint_key=hint_key, - blur_strength=self.blur_strength, - canny_threshold=self.canny_threshold, - ) - else: - # prompt upsampler for seg, depth - input_control_path = self.control_inputs[self.upsampler_hint_key].get("input_control", None) - - prompt = self.prompt_upsampler._prompt_upsample_with_offload(prompt=prompt, video_path=input_control_path) - return prompt - - def _load_model(self): - self.model = load_model_by_config( - config_job_name=self.model_name, - config_file="cosmos_transfer1/diffusion/config/transfer/config.py", - model_class=self.model_class, - base_checkpoint_dir=self.checkpoint_dir, - ) - - # load the hint encoders. these encoders are run along with the main model to provide additional context - def _load_network(self): - # This load seems to be non-functional for av-sample checkpoints. The base_model loading in build_model is required - if self.checkpoint_name == "": - load_network_model(self.model, "") - else: - load_network_model(self.model, f"{self.checkpoint_dir}/{self.checkpoint_name}") - if len(self.control_inputs) > 1: - hint_encoders = torch.nn.ModuleList([]) - for key, spec in self.control_inputs.items(): - if key in valid_hint_keys: - model = load_model_by_config( - config_job_name=self.model_name, - config_file="cosmos_transfer1/diffusion/config/transfer/config.py", - model_class=self.model_class, - base_checkpoint_dir=self.checkpoint_dir, - ) - log.info(f"Loading ctrl model from ckpt_path: {spec['ckpt_path']}") - load_network_model(model, spec["ckpt_path"]) - hint_encoders.append(model.model.net) - del model - torch.cuda.empty_cache() - self.model.hint_encoders = hint_encoders - else: - for _, spec in self.control_inputs.items(): - log.info(f"Loading ctrl model from ckpt_path: {spec['ckpt_path']}") - - if os.path.exists(spec["ckpt_path"]): - net_state_dict = torch.load(spec["ckpt_path"], map_location="cpu", weights_only=False) - else: - net_state_dict = torch.load( - f"{self.checkpoint_dir}/{spec['ckpt_path']}", map_location="cpu", weights_only=False - ) - non_strict_load_model(self.model.model, net_state_dict) - - if self.process_group is not None: - log.info("Enabling CP in base model") - self.model.model.net.enable_context_parallel(self.process_group) - self.model.model.base_model.net.enable_context_parallel(self.process_group) - if hasattr(self.model.model, "hint_encoders"): - log.info("Enabling CP in hint encoders") - self.model.model.hint_encoders.net.enable_context_parallel(self.process_group) - - def _load_tokenizer(self): - load_tokenizer_model(self.model, f"{self.checkpoint_dir}/{COSMOS_TOKENIZER_CHECKPOINT}") - - def _run_tokenizer_decoding(self, sample: torch.Tensor, use_batch: bool = True) -> np.ndarray: - """Decode latent samples to video frames using the tokenizer 
decoder. - - Args: - sample: Latent tensor from diffusion model [B, C, T, H, W] - - Returns: - np.ndarray: Decoded video frames as uint8 numpy array [T, H, W, C] - with values in range [0, 255] - """ - # Decode video - if sample.shape[0] == 1 or use_batch: - video = (1.0 + self.model.decode(sample)).clamp(0, 2) / 2 # [B, 3, T, H, W] - else: - # Do decoding for each batch sequentially to prevent OOM. - samples = [] - for sample_i in sample: - samples += [self.model.decode(sample_i.unsqueeze(0)).cpu()] - samples = (torch.cat(samples) + 1).clamp(0, 2) / 2 - # samples = (torch.stack(samples) + 1).clamp(0, 2) / 2 - - # Stitch the patches together to form the final video. - patch_h, patch_w = samples.shape[-2:] - orig_size = (patch_w, patch_h) - aspect_ratio = detect_aspect_ratio(orig_size) - stitch_w, stitch_h = get_upscale_size(orig_size, aspect_ratio, upscale_factor=3) - n_img_w = (stitch_w - 1) // patch_w + 1 - n_img_h = (stitch_h - 1) // patch_h + 1 - overlap_size_w = overlap_size_h = 0 - if n_img_w > 1: - overlap_size_w = (n_img_w * patch_w - stitch_w) // (n_img_w - 1) - if n_img_h > 1: - overlap_size_h = (n_img_h * patch_h - stitch_h) // (n_img_h - 1) - video = merge_patches_into_video(samples, overlap_size_h, overlap_size_w, n_img_h, n_img_w) - video = torch.nn.functional.interpolate(video[0], size=(patch_h * 3, patch_w * 3), mode="bicubic")[None] - video = video.clamp(0, 1) - video = (video * 255).to(torch.uint8).cpu() - return video - - def _run_model_with_offload( - self, - prompt_embeddings: list[torch.Tensor], - video_paths: list[str], - negative_prompt_embeddings: Optional[list[torch.Tensor]] = None, - control_inputs_list: list[dict] = None, - ) -> list[np.ndarray]: - """Generate world representation with automatic model offloading. - - Wraps the core generation process with model loading/offloading logic - to minimize GPU memory usage during inference. - - Args: - prompt_embeddings: List of text embedding tensors from T5 encoder - video_paths: List of paths to input videos - negative_prompt_embeddings: Optional list of embeddings for negative prompt guidance - control_inputs_list: List of control input dictionaries - - Returns: - list[np.ndarray]: List of generated world representations as numpy arrays - """ - if self.offload_tokenizer: - self._load_tokenizer() - - if self.offload_network: - self._load_network() - - prompt_embeddings = torch.cat(prompt_embeddings) - if negative_prompt_embeddings is not None: - negative_prompt_embeddings = torch.cat(negative_prompt_embeddings) - - samples = self._run_model( - prompt_embeddings=prompt_embeddings, - negative_prompt_embeddings=negative_prompt_embeddings, - video_paths=video_paths, - control_inputs_list=control_inputs_list, - ) - - if self.offload_network: - self._offload_network() - - if self.offload_tokenizer: - self._offload_tokenizer() - - return samples - - def _run_model( - self, - prompt_embeddings: torch.Tensor, # [B, ...] - video_paths: list[str], # [B] - negative_prompt_embeddings: Optional[torch.Tensor] = None, # [B, ...] or None - control_inputs_list: list[dict] = None, # [B] list of dicts - ) -> np.ndarray: - """ - Batched world generation with model offloading. - Each batch element corresponds to a (prompt, video, control_inputs) triple. 
- """ - B = len(video_paths) - assert prompt_embeddings.shape[0] == B, "Batch size mismatch for prompt embeddings" - if negative_prompt_embeddings is not None: - assert negative_prompt_embeddings.shape[0] == B, "Batch size mismatch for negative prompt embeddings" - assert len(control_inputs_list) == B, "Batch size mismatch for control_inputs_list" - - log.info("Starting data augmentation") - - # Process regional prompts if provided - log.info(f"regional_prompts passed to _run_model: {self.regional_prompts}") - log.info(f"region_definitions passed to _run_model: {self.region_definitions}") - regional_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(self.regional_prompts) - regional_contexts = None - region_masks = None - if self.regional_prompts and self.region_definitions: - # Prepare regional prompts using the existing text embedding function - _, regional_contexts, region_masks = prepare_regional_prompts( - model=self.model, - global_prompt=prompt_embeddings, # Pass the already computed global embedding - regional_prompts=regional_embeddings, - region_definitions=self.region_definitions, - batch_size=1, # Adjust based on your batch size - time_dim=self.num_video_frames, - height=self.height // self.model.tokenizer.spatial_compression_factor, - width=self.width // self.model.tokenizer.spatial_compression_factor, - device=torch.device("cuda"), - compression_factor=self.model.tokenizer.spatial_compression_factor, - ) - - is_upscale_case = any("upscale" in control_inputs for control_inputs in control_inputs_list) - # Get video batch and state shape - data_batch, state_shape = get_batched_ctrl_batch( - model=self.model, - prompt_embeddings=prompt_embeddings, # [B, ...] - negative_prompt_embeddings=negative_prompt_embeddings, - height=self.height, - width=self.width, - fps=self.fps, - num_video_frames=self.num_video_frames, - input_video_paths=video_paths, # [B] - control_inputs_list=control_inputs_list, # [B] - blur_strength=self.blur_strength, - canny_threshold=self.canny_threshold, - ) - - if regional_contexts is not None: - data_batch["regional_contexts"] = regional_contexts - data_batch["region_masks"] = region_masks - - log.info("Completed data augmentation") - - hint_key = data_batch["hint_key"] - control_input = data_batch[hint_key] # [B, C, T, H, W] - input_video = data_batch.get("input_video", None) - control_weight = data_batch.get("control_weight", None) - num_new_generated_frames = self.num_video_frames - self.num_input_frames - B, C, T, H, W = control_input.shape - if (T - self.num_input_frames) % num_new_generated_frames != 0: # pad duplicate frames at the end - pad_t = num_new_generated_frames - ((T - self.num_input_frames) % num_new_generated_frames) - pad_frames = control_input[:, :, -1:].repeat(1, 1, pad_t, 1, 1) - control_input = torch.cat([control_input, pad_frames], dim=2) - if input_video is not None: - pad_video = input_video[:, :, -1:].repeat(1, 1, pad_t, 1, 1) - input_video = torch.cat([input_video, pad_video], dim=2) - num_total_frames_with_padding = control_input.shape[2] - if ( - isinstance(control_weight, torch.Tensor) - and control_weight.ndim > 5 - and num_total_frames_with_padding > control_weight.shape[3] - ): - pad_t = num_total_frames_with_padding - control_weight.shape[3] - pad_weight = control_weight[:, :, :, -1:].repeat(1, 1, 1, pad_t, 1, 1) - control_weight = torch.cat([control_weight, pad_weight], dim=3) - else: - num_total_frames_with_padding = T - N_clip = (num_total_frames_with_padding - self.num_input_frames) // 
num_new_generated_frames - - video = [] - prev_frames = None - for i_clip in tqdm(range(N_clip)): - # data_batch_i = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in data_batch.items()} - data_batch_i = {k: v for k, v in data_batch.items()} - start_frame = num_new_generated_frames * i_clip - end_frame = num_new_generated_frames * (i_clip + 1) + self.num_input_frames - - # Prepare x_sigma_max - if input_video is not None: - if is_upscale_case: - x_sigma_max = [] - for b in range(B): - input_frames = input_video[b : b + 1, :, start_frame:end_frame].cuda() - x0 = self.model.encode(input_frames).contiguous() - x_sigma_max.append(self.model.get_x_from_clean(x0, self.sigma_max, seed=(self.seed + i_clip))) - x_sigma_max = torch.cat(x_sigma_max) - else: - input_frames = input_video[:, :, start_frame:end_frame].cuda() - x0 = self.model.encode(input_frames).contiguous() - x_sigma_max = self.model.get_x_from_clean(x0, self.sigma_max, seed=(self.seed + i_clip)) - - else: - x_sigma_max = None - - data_batch_i[hint_key] = control_input[:, :, start_frame:end_frame].cuda() - latent_hint = [] - log.info("Starting latent encoding") - for b in range(B): - data_batch_p = {k: v for k, v in data_batch_i.items()} - data_batch_p[hint_key] = data_batch_i[hint_key][b : b + 1] - if len(control_inputs_list) >= 1 and len(control_inputs_list[0]) > 1: - latent_hint_i = [] - for idx in range(0, data_batch_p[hint_key].size(1), 3): - x_rgb = data_batch_p[hint_key][:, idx : idx + 3] - latent_hint_i.append(self.model.encode(x_rgb)) - latent_hint.append(torch.cat(latent_hint_i).unsqueeze(0)) - else: - latent_hint.append(self.model.encode_latent(data_batch_p)) - data_batch_i["latent_hint"] = latent_hint = torch.cat(latent_hint) - log.info("Completed latent encoding") - # Resize control_weight if needed - if isinstance(control_weight, torch.Tensor) and control_weight.ndim > 4: - control_weight_t = control_weight[..., start_frame:end_frame, :, :] - t, h, w = latent_hint.shape[-3:] - data_batch_i["control_weight"] = resize_control_weight_map(control_weight_t, (t, h // 2, w // 2)) - - # Prepare condition_latent for long video generation - if i_clip == 0: - num_input_frames = 0 - latent_tmp = latent_hint if latent_hint.ndim == 5 else latent_hint[:, 0] - condition_latent = torch.zeros_like(latent_tmp) - else: - num_input_frames = self.num_input_frames - prev_frames = split_video_into_patches(prev_frames, control_input.shape[-2], control_input.shape[-1]) - input_frames = prev_frames.bfloat16().cuda() / 255.0 * 2 - 1 - condition_latent = self.model.encode(input_frames).contiguous() - - # Generate video frames for this clip (batched) - log.info("Starting diffusion sampling") - latents = generate_world_from_control( - model=self.model, - state_shape=state_shape, - is_negative_prompt=True, - data_batch=data_batch_i, - guidance=self.guidance, - num_steps=self.num_steps, - seed=(self.seed + i_clip), - condition_latent=condition_latent, - num_input_frames=num_input_frames, - sigma_max=self.sigma_max if x_sigma_max is not None else None, - x_sigma_max=x_sigma_max, - use_batch_processing=False if is_upscale_case else True, - ) - log.info("Completed diffusion sampling") - log.info("Starting VAE decode") - frames = self._run_tokenizer_decoding( - latents, use_batch=False if is_upscale_case else True - ) # [B, T, H, W, C] or similar - log.info("Completed VAE decode") - - if i_clip == 0: - video.append(frames) - else: - video.append(frames[:, :, self.num_input_frames :]) - - prev_frames = torch.zeros_like(frames) - prev_frames[:, 
:, : self.num_input_frames] = frames[:, :, -self.num_input_frames :] - - video = torch.cat(video, dim=2)[:, :, :T] - video = video.permute(0, 2, 3, 4, 1).numpy() - return video - - def generate( - self, - prompt: str | list[str], - video_path: str | list[str], - negative_prompt: Optional[str | list[str]] = None, - control_inputs: dict | list[dict] = None, - save_folder: str = "outputs/", - batch_size: int = 1, - ) -> tuple[np.ndarray, str | list[str]] | None: - """Generate video from text prompt and control video. - - Pipeline steps: - 1. Run safety checks on input prompt - 2. Convert prompt to embeddings - 3. Generate video frames using diffusion - 4. Run safety checks and apply face blur on generated video frames - - Args: - prompt: Text description of desired video - video_path: Path to input video - negative_prompt: Optional text to guide what not to generate - control_inputs: Control inputs for guided generation - save_folder: Folder to save intermediate files - batch_size: Number of videos to process simultaneously - - Returns: - tuple: ( - Generated video frames as uint8 np.ndarray [T, H, W, C], - Final prompt used for generation (may be enhanced) - ), or None if content fails guardrail safety checks - """ - # log.info(f"Run with prompt: {prompt}") - # log.info(f"Run with video path: {video_path}") - # log.info(f"Run with negative prompt: {negative_prompt}") - - prompts = [prompt] if isinstance(prompt, str) else prompt - video_paths = [video_path] if isinstance(video_path, str) else video_path - control_inputs_list = [control_inputs] if not isinstance(control_inputs, list) else control_inputs - - assert len(video_paths) == batch_size, "Number of prompts and videos must match" - assert len(control_inputs_list) == batch_size, "Number of control inputs must match batch size" - log.info(f"Running batch generation with batch_size={batch_size}") - - # Process prompts in batch - all_videos = [] - all_final_prompts = [] - - # Upsample prompts if enabled - if self.prompt_upsampler and int(os.environ["RANK"]) == 0: - self._push_torchrun_environ_variables() - upsampled_prompts = [] - for i, (single_prompt, single_video_path) in enumerate(zip(prompts, video_paths)): - log.info(f"Upsampling prompt {i+1}/{batch_size}: {single_prompt[:50]}...") - video_save_subfolder = os.path.join(save_folder, f"video_{i}") - os.makedirs(video_save_subfolder, exist_ok=True) - upsampled_prompt = self._process_prompt_upsampler( - single_prompt, single_video_path, video_save_subfolder - ) - upsampled_prompts.append(upsampled_prompt) - log.info(f"Upsampled prompt {i+1}: {upsampled_prompt[:50]}...") - self._pop_torchrun_environ_variables() - prompts = upsampled_prompts - - log.info("Running guardrail checks on all prompts") - safe_indices = [] - for i, single_prompt in enumerate(prompts): - is_safe = self._run_guardrail_on_prompt_with_offload(single_prompt) - if is_safe: - safe_indices.append(i) - else: - log.critical(f"Input text prompt {i+1} is not safe") - - if not safe_indices: - log.critical("All prompts failed safety checks") - return None - - safe_prompts = [prompts[i] for i in safe_indices] - safe_video_paths = [video_paths[i] for i in safe_indices] - safe_control_inputs = [control_inputs_list[i] for i in safe_indices] - - log.info("Running text embedding on all prompts") - all_prompt_embeddings = [] - - # Process in smaller batches if needed to avoid OOM - embedding_batch_size = min(batch_size, 8) # Process embeddings in smaller batches - for i in range(0, len(safe_prompts), embedding_batch_size): - 
batch_prompts = safe_prompts[i : i + embedding_batch_size] - if negative_prompt: - batch_prompts_with_neg = [] - for p in batch_prompts: - batch_prompts_with_neg.extend([p, negative_prompt]) - else: - batch_prompts_with_neg = batch_prompts - log.info("Starting T5 compute") - prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(batch_prompts_with_neg) - log.info("Completed T5 compute") - # Separate positive and negative embeddings - if negative_prompt: - for j in range(0, len(prompt_embeddings), 2): - all_prompt_embeddings.append((prompt_embeddings[j], prompt_embeddings[j + 1])) - else: - for emb in prompt_embeddings: - all_prompt_embeddings.append((emb, None)) - log.info("Finish text embedding on prompt") - - # Generate videos in batches - log.info("Run generation") - - all_neg_embeddings = [emb[1] for emb in all_prompt_embeddings] - all_prompt_embeddings = [emb[0] for emb in all_prompt_embeddings] - videos = self._run_model_with_offload( - prompt_embeddings=all_prompt_embeddings, - negative_prompt_embeddings=all_neg_embeddings, - video_paths=safe_video_paths, - control_inputs_list=safe_control_inputs, - ) - log.info("Finish generation") - - log.info("Run guardrail on generated videos") - for i, video in enumerate(videos): - safe_video = self._run_guardrail_on_video_with_offload(video) - if safe_video is not None: - all_videos.append(safe_video) - all_final_prompts.append(safe_prompts[i]) - else: - log.critical(f"Generated video {i+1} is not safe") - if not all_videos: - log.critical("All generated videos failed safety checks") - return None - return all_videos, all_final_prompts - - -class DiffusionControl2WorldMultiviewGenerationPipeline(DiffusionControl2WorldGenerationPipeline): - def __init__(self, *args, is_lvg_model=False, n_clip_max=-1, **kwargs): - super(DiffusionControl2WorldMultiviewGenerationPipeline, self).__init__(*args, **kwargs) - self.is_lvg_model = is_lvg_model - self.n_clip_max = n_clip_max - - def _run_tokenizer_decoding(self, sample: torch.Tensor): - """Decode latent samples to video frames using the tokenizer decoder. - - Args: - sample: Latent tensor from diffusion model [B, C, T, H, W] - - Returns: - np.ndarray: Decoded video frames as uint8 numpy array [T, H, W, C] - with values in range [0, 255] - """ - - if self.model.n_views == 5: - video_arrangement = [1, 0, 2, 3, 0, 4] - elif self.model.n_views == 6: - video_arrangement = [1, 0, 2, 4, 3, 5] - else: - raise ValueError(f"Unsupported number of views: {self.model.n_views}") - # Decode video - video = (1.0 + self.model.decode(sample)).clamp(0, 2) / 2 # [B, 3, T, H, W] - video_segments = einops.rearrange(video, "b c (v t) h w -> b c v t h w", v=self.model.n_views) - grid_video = torch.stack( - [video_segments[:, :, i] for i in video_arrangement], - dim=2, - ) - grid_video = einops.rearrange(grid_video, "b c (h w) t h1 w1 -> b c t (h h1) (w w1)", h=2, w=3) - grid_video = (grid_video[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy() - video = (video[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy() - - return [grid_video, video] - - def _run_model_with_offload( - self, - prompt_embedding: torch.Tensor, - view_condition_video="", - initial_condition_video="", - control_inputs: dict = None, - ) -> np.ndarray: - """Generate world representation with automatic model offloading. - - Wraps the core generation process with model loading/offloading logic - to minimize GPU memory usage during inference. 
- - Args: - prompt_embedding: Text embedding tensor from T5 encoder - view_condition_video: Path to input sv view condition video - initial_condition_video: Path to input mv initial frames - control_inputs: Dictionary of control modalities and corresponding inputs - - Returns: - np.ndarray: Generated world representation as numpy array - """ - if self.offload_tokenizer: - self._load_tokenizer() - - if self.offload_network: - self._load_network() - - sample = self._run_model( - prompt_embedding, view_condition_video, initial_condition_video, control_inputs=control_inputs - ) - - if self.offload_network: - self._offload_network() - - if self.offload_tokenizer: - self._offload_tokenizer() - - return sample - - def _run_model( - self, - embedding: torch.Tensor, - view_condition_video="", - initial_condition_video="", - control_inputs: dict = None, - ) -> torch.Tensor: - """Generate video frames using the diffusion model. - - Args: - prompt_embedding: Text embedding tensor from T5 encoder - view_condition_video: Path to input sv view condition video - initial_condition_video: Path to input mv initial frames - control_inputs: Dictionary of control modalities and corresponding inputs - - Returns: - Tensor of generated video frames - - Note: - Model and tokenizer are automatically offloaded after inference - if offloading is enabled. - """ - # Get video batch and state shape - assert len(embedding) == self.model.n_views - - view_condition_video, fps = read_video_or_image_into_frames_BCTHW( - view_condition_video, - normalize=False, # s.t. output range is [0, 255] - max_frames=6000, - also_return_fps=True, - ) - view_condition_video = resize_video( - view_condition_video, self.height, self.width, interpolation=cv2.INTER_LINEAR - ) - view_condition_video = torch.from_numpy(view_condition_video) - total_T = view_condition_video.shape[2] - - data_batch, state_shape = get_video_batch_for_multiview_model( - model=self.model, - prompt_embedding=embedding, - height=self.height, - width=self.width, - fps=self.fps, - num_video_frames=self.num_video_frames * len(embedding), - frame_repeat_negative_condition=0, - ) - - self.model.condition_location = "first_cam_and_first_n" if self.is_lvg_model else "first_cam" - - if self.is_lvg_model: - if os.path.isdir(initial_condition_video): - initial_condition_videos = [] - fnames = sorted(os.listdir(initial_condition_video)) - for fname in fnames: - if fname.endswith(".mp4"): - try: - input_view_id = int(fname.split(".")[0]) - except ValueError: - log.warning(f"Could not parse video file name {fname} into view id") - continue - initial_condition_video_n = read_video_or_image_into_frames_BCTHW( - fname, - normalize=False, - max_frames=self.num_input_frames, - also_return_fps=True, - ) - initial_condition_videos.append(torch.from_numpy(initial_condition_video_n)) - initial_condition_video = torch.cat(initial_condition_videos, dim=2) - else: - initial_condition_video, _ = read_video_or_image_into_frames_BCTHW( - initial_condition_video, - normalize=False, - max_frames=6000, - also_return_fps=True, - ) # B C (V T) H W - initial_condition_video = torch.from_numpy(initial_condition_video) - else: - initial_condition_video = None - - data_batch = get_ctrl_batch_mv( - self.height, self.width, data_batch, total_T, control_inputs, self.model.n_views, self.num_video_frames - ) # multicontrol inputs are concatenated channel wise, [-1,1] range - - hint_key = data_batch["hint_key"] - input_video = None - control_input = data_batch[hint_key] - control_weight = 
data_batch["control_weight"] - - num_new_generated_frames = self.num_video_frames - self.num_input_frames # 57 - 9 = 48 - B, C, T, H, W = control_input.shape - T = T // self.model.n_views - assert T == total_T - # Different from other examples, we use a different logic to determine total generated duration: - # we check for the maximum number of clips that can be fit in to the duration of ctrl input and condition input - # and implicitly trim these videos to that duration - if self.is_lvg_model: - N_clip = (T - self.num_input_frames) // num_new_generated_frames - if self.n_clip_max > 0: - N_clip = min(self.n_clip_max, N_clip) - else: - N_clip = 1 - log.info("Model is not Long-video generation model, overwriting N_clip to 1") - - video = [] - for i_clip in tqdm(range(N_clip)): - data_batch_i = {k: v for k, v in data_batch.items()} - start_frame = num_new_generated_frames * i_clip - end_frame = num_new_generated_frames * (i_clip + 1) + self.num_input_frames - - if input_video is not None: - x_sigma_max = [] - for b in range(B): - input_frames = input_video[b : b + 1, :, start_frame:end_frame].cuda() - x0 = self.model.encode(input_frames).contiguous() - x_sigma_max.append(self.model.get_x_from_clean(x0, self.sigma_max, seed=(self.seed + i_clip))) - x_sigma_max = torch.cat(x_sigma_max) - else: - x_sigma_max = None - - control_input_BVCT = einops.rearrange(control_input, "B C (V T) H W -> (B V) C T H W", V=self.model.n_views) - control_input_i = control_input_BVCT[:, :, start_frame:end_frame].cuda() - - data_batch_i[hint_key] = einops.rearrange( - control_input_i, "(B V) C T H W -> B C (V T) H W", V=self.model.n_views - ) - - condition_input_i = view_condition_video[:, :, start_frame:end_frame].cuda() - - latent_hint = [] - for b in range(B): - data_batch_p = {k: v for k, v in data_batch_i.items()} - data_batch_p[hint_key] = data_batch_i[hint_key][b : b + 1] - if len(control_inputs) > 1: - latent_hint_i = [] - for idx in range(0, data_batch_p[hint_key].size(1), 3): - x_rgb = data_batch_p[hint_key][:, idx : idx + 3] - latent_hint_i.append(self.model.encode(x_rgb)) - latent_hint.append(torch.cat(latent_hint_i).unsqueeze(0)) - else: - latent_hint.append(self.model.encode_latent(data_batch_p)) - data_batch_i["latent_hint"] = latent_hint = torch.cat(latent_hint) - - if "regional_contexts" in data_batch_i: - data_batch_i["regional_contexts"] = broadcast(data_batch_i["regional_contexts"], to_tp=True, to_cp=True) - data_batch_i["region_masks"] = broadcast(data_batch_i["region_masks"], to_tp=True, to_cp=True) - - if isinstance(control_weight, torch.Tensor) and control_weight.ndim > 4: - control_weight_t = control_weight[..., start_frame:end_frame, :, :].cuda() - t, h, w = latent_hint.shape[-3:] - data_batch_i["control_weight"] = resize_control_weight_map(control_weight_t, (t, h // 2, w // 2)) - - if i_clip == 0: - if initial_condition_video is not None: - prev_frames_blank = torch.zeros((B, self.model.n_views, C, self.num_video_frames, H, W)).to( - view_condition_video - ) - - initial_condition_video_frames_BVCT = einops.rearrange( - initial_condition_video, "B C (V T) H W -> B V C T H W", V=self.model.n_views - ) - prev_frames_blank[:, :, :, : self.num_input_frames] = initial_condition_video_frames_BVCT[ - :, :, :, start_frame : start_frame + self.num_input_frames - ].cuda() - prev_frames = einops.rearrange(prev_frames_blank, "B V C T H W -> B C (V T) H W") - num_input_frames = self.num_input_frames - else: - num_input_frames = 0 - prev_frames = None - else: - num_input_frames = self.num_input_frames - 
condition_latent = self.get_condition_latent( - state_shape, - data_batch_i, - cond_video=condition_input_i, - prev_frames=prev_frames, - patch_h=H, - patch_w=W, - skip_reencode=False, - ).bfloat16() - # Generate video frames - latents = generate_world_from_control( - model=self.model, - state_shape=self.model.state_shape, - is_negative_prompt=False, - data_batch=data_batch_i, - guidance=self.guidance, - num_steps=self.num_steps, - seed=(self.seed + i_clip), - condition_latent=condition_latent, - num_input_frames=num_input_frames, - sigma_max=self.sigma_max if x_sigma_max is not None else None, - x_sigma_max=x_sigma_max, - augment_sigma=0.0, - ) - torch.cuda.empty_cache() - _, frames = self._run_tokenizer_decoding(latents) # T H W C - frames = torch.from_numpy(frames).permute(3, 0, 1, 2)[None] # 1 C (V T) H W - frames_BVCT = einops.rearrange(frames, "B C (V T) H W -> B V C T H W", V=self.model.n_views) - if i_clip == 0: - video.append(frames_BVCT) - else: - frames_BVCT_non_overlap = frames_BVCT[:, :, :, num_input_frames:] - video.append(frames_BVCT_non_overlap) - - prev_frames = torch.zeros_like(frames_BVCT) - n_copy = max(1, abs(self.num_input_frames)) - prev_frames[:, :, :, :n_copy] = frames_BVCT[:, :, :, -n_copy:] - prev_frames = einops.rearrange(prev_frames, "B V C T H W -> B C (V T) H W") - - video = torch.cat(video, dim=3) - video = einops.rearrange(video, "B V C T H W -> B C (V T) H W") - video = video[0].permute(1, 2, 3, 0).numpy() # T H W C - return video - - def get_condition_latent( - self, - state_shape, - data_batch_i, - cond_video=None, - prev_frames=None, - patch_h=1024, - patch_w=1024, - skip_reencode=False, - prev_latents=None, - ): - """ - Create the condition latent used in this loop for generation from RGB frames - Args: - model: - state_shape: tuple (C T H W), shape of latent to be generated - data_batch_i: (dict) this is only used to get batch size - multi_cam: (bool) whether to use multicam processing or revert to original behavior from tpsp_demo - cond_video: (tensor) the front view video for conditioning sv2mv - prev_frames: (tensor) frames generated in previous loop - patch_h: (int) - patch_w: (int) - skip_reencode: (bool) whether to use the tokenizer to encode prev_frames, or read from prev_latents directly - prev_latents: (tensor) latent generated in previous loop, must not be None if skip_reencode - - Returns: - - """ - # this might be not 1 when patching is used - B = data_batch_i["video"].shape[0] - - latent_sample = torch.zeros(state_shape).unsqueeze(0).repeat(B, 1, 1, 1, 1).cuda() # B, C, (V T), H, W - latent_sample = einops.rearrange(latent_sample, "B C (V T) H W -> B V C T H W", V=self.model.n_views) - log.info(f"model.sigma_data {self.model.sigma_data}") - if self.model.config.conditioner.video_cond_bool.condition_location.endswith("first_n"): - if skip_reencode: - assert prev_latents is not None - prev_latents = einops.rearrange(prev_latents, "B C (V T) H W -> B V C T H W", V=self.model.n_views) - latent_sample = prev_latents.clone() - else: - prev_frames = split_video_into_patches(prev_frames, patch_h, patch_w) - for b in range(prev_frames.shape[0]): - input_frames = prev_frames[b : b + 1].cuda() / 255.0 * 2 - 1 - input_frames = einops.rearrange(input_frames, "1 C (V T) H W -> V C T H W", V=self.model.n_views) - encoded_frames = self.model.tokenizer.encode(input_frames).contiguous() * self.model.sigma_data - latent_sample[b : b + 1, :] = encoded_frames - - if self.model.config.conditioner.video_cond_bool.condition_location.startswith("first_cam"): - 
assert cond_video is not None - cond_video = split_video_into_patches(cond_video, patch_h, patch_w) - for b in range(cond_video.shape[0]): - input_frames = cond_video[b : b + 1].cuda() / 255.0 * 2 - 1 - # input_frames = einops.rearrange(input_frames, "1 C (V T) H W -> V C T H W", V=self.model.n_views)[:1] - latent_sample[ - b : b + 1, - 0, - ] = ( - self.model.tokenizer.encode(input_frames).contiguous() * self.model.sigma_data - ) - - latent_sample = einops.rearrange(latent_sample, " B V C T H W -> B C (V T) H W") - log.info(f"latent_sample, {latent_sample[:,0,:,0,0]}") - - return latent_sample - - def build_mv_prompt(self, mv_prompts, n_views): - """ - Apply multiview prompt formatting to the input prompt such that the text conditioning matches that used during - training. - Args: - mv_prompts: captions of one scene, with the prompt of each view separated by ";" - n_views: number of cameras to format the caption to - - Returns: - list: the per-view prompts with camera-direction prefixes applied - - """ - if n_views == 5: - base_prompts = [ - "The video is captured from a camera mounted on a car. The camera is facing forward.", - "The video is captured from a camera mounted on a car. The camera is facing to the left.", - "The video is captured from a camera mounted on a car. The camera is facing to the right.", - "The video is captured from a camera mounted on a car. The camera is facing the rear left side.", - "The video is captured from a camera mounted on a car. The camera is facing the rear right side.", - ] - elif n_views == 6: - base_prompts = [ - "The video is captured from a camera mounted on a car. The camera is facing forward.", - "The video is captured from a camera mounted on a car. The camera is facing to the left.", - "The video is captured from a camera mounted on a car. The camera is facing to the right.", - "The video is captured from a camera mounted on a car. The camera is facing backwards.", - "The video is captured from a camera mounted on a car. The camera is facing the rear left side.", - "The video is captured from a camera mounted on a car. The camera is facing the rear right side.", - ] - - log.info(f"Reading multiview prompts, found {len(mv_prompts)} splits") - n = len(mv_prompts) - if n < n_views: - mv_prompts += base_prompts[n:] - else: - mv_prompts = mv_prompts[:n_views] - - for vid, p in enumerate(mv_prompts): - if not p.startswith(base_prompts[vid]): - mv_prompts[vid] = base_prompts[vid] + " " + p - log.info(f"Adding missing camera caption to view {vid}, {p[:30]}") - - log.info(f"Processed multiview prompts, {len(mv_prompts)} splits") - return mv_prompts - - def generate( - self, - prompts: list, - view_condition_video: str, - initial_condition_video: str, - control_inputs: dict = None, - save_folder: str = "outputs/", - ) -> tuple[np.ndarray, str] | None: - """Generate video from text prompt and control video. - - Pipeline steps: - 1. Run safety checks on input prompt - 2. Convert prompt to embeddings - 3. Generate video frames using diffusion - 4.
Run safety checks and apply face blur on generated video frames - - Args: - control_inputs: Control inputs for guided generation - save_folder: Folder to save intermediate files - - Returns: - tuple: ( - Generated video frames as uint8 np.ndarray [T, H, W, C], - Final prompt used for generation (may be enhanced) - ), or None if content fails guardrail safety checks - """ - - log.info(f"Run with view condition video path: {view_condition_video}") - if initial_condition_video: - log.info(f"Run with initial condition video path: {initial_condition_video}") - mv_prompts = self.build_mv_prompt(prompts, self.model.n_views) - log.info(f"Run with prompt: {mv_prompts}") - - # Process prompts into multiview format - log.info("Run guardrail on prompt") - is_safe = self._run_guardrail_on_prompt_with_offload(". ".join(mv_prompts)) - if not is_safe: - log.critical("Input text prompt is not safe") - return None - log.info("Pass guardrail on prompt") - - prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(mv_prompts) - prompt_embedding = torch.concat(prompt_embeddings, dim=0).cuda() - - log.info("Finish text embedding on prompt") - - # Generate video - log.info("Run generation") - - video = self._run_model_with_offload( - prompt_embedding, - view_condition_video, - initial_condition_video, - control_inputs=control_inputs, - ) - log.info("Finish generation") - log.info("Run guardrail on generated video") - video = self._run_guardrail_on_video_with_offload(video) - if video is None: - log.critical("Generated video is not safe") - raise ValueError("Guardrail check failed: Generated video is unsafe") - - log.info("Pass guardrail on generated video") - - return video, mv_prompts - - -class DistilledControl2WorldGenerationPipeline(DiffusionControl2WorldGenerationPipeline): - """Pipeline for distilled ControlNet video2video inference.""" - - def _load_network(self): - log.info("Loading distilled consolidated checkpoint") - - # Load consolidated checkpoint - from cosmos_transfer1.diffusion.inference.inference_utils import skip_init_linear - - with skip_init_linear(): - self.model.set_up_model() - checkpoint_path = f"{self.checkpoint_dir}/{self.checkpoint_name}" - checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) - state_dict = checkpoint.get("model", checkpoint) - - # Split into base and control components - base_state_dict = {} - ctrl_state_dict = {} - - for k, v in state_dict.items(): - if k.startswith("net.base_model.net."): - base_key = k[len("net.base_model.net.") :] - base_state_dict[base_key] = v - elif k.startswith("net.net_ctrl."): - ctrl_key = k[len("net.net_ctrl.") :] - ctrl_state_dict[ctrl_key] = v - - # Load base model weights - if base_state_dict: - self.model.model["net"].base_model.net.load_state_dict(base_state_dict, strict=False) - self.model.model.base_model.load_state_dict(base_state_dict, strict=False) - # Load control weights - if ctrl_state_dict: - self.model.model["net"].net_ctrl.load_state_dict(ctrl_state_dict, strict=False) - self.model.model.cuda() - - if self.process_group is not None: - log.info("Enabling CP in base model") - self.model.model.net.enable_context_parallel(self.process_group) - - def _run_model( - self, - prompt_embeddings: torch.Tensor, # [B, ...] - video_paths: list[str], # [B] - negative_prompt_embeddings: Optional[torch.Tensor] = None, # [B, ...] or None - control_inputs_list: list[dict] = None, # [B] list of dicts - ) -> np.ndarray: - """ - Batched world generation with model offloading. 
- Each batch element corresponds to a (prompt, video, control_inputs) triple. - """ - B = len(video_paths) - print(f"video paths: {video_paths}") - assert prompt_embeddings.shape[0] == B, "Batch size mismatch for prompt embeddings" - if negative_prompt_embeddings is not None: - assert negative_prompt_embeddings.shape[0] == B, "Batch size mismatch for negative prompt embeddings" - assert len(control_inputs_list) == B, "Batch size mismatch for control_inputs_list" - - log.info("Starting data augmentation") - - log.info(f"Regional prompts not supported when using distilled model, dropping: {self.regional_prompts}") - - # Get video batch and state shape - data_batch, state_shape = get_batched_ctrl_batch( - model=self.model, - prompt_embeddings=prompt_embeddings, # [B, ...] - negative_prompt_embeddings=negative_prompt_embeddings, - height=self.height, - width=self.width, - fps=self.fps, - num_video_frames=self.num_video_frames, - input_video_paths=video_paths, # [B] - control_inputs_list=control_inputs_list, # [B] - blur_strength=self.blur_strength, - canny_threshold=self.canny_threshold, - ) - - log.info("Completed data augmentation") - - hint_key = data_batch["hint_key"] - control_input = data_batch[hint_key] # [B, C, T, H, W] - input_video = data_batch.get("input_video", None) - control_weight = data_batch.get("control_weight", None) - num_new_generated_frames = self.num_video_frames - self.num_input_frames - B, C, T, H, W = control_input.shape - if (T - self.num_input_frames) % num_new_generated_frames != 0: # pad duplicate frames at the end - pad_t = num_new_generated_frames - ((T - self.num_input_frames) % num_new_generated_frames) - pad_frames = control_input[:, :, -1:].repeat(1, 1, pad_t, 1, 1) - control_input = torch.cat([control_input, pad_frames], dim=2) - if input_video is not None: - pad_video = input_video[:, :, -1:].repeat(1, 1, pad_t, 1, 1) - input_video = torch.cat([input_video, pad_video], dim=2) - num_total_frames_with_padding = control_input.shape[2] - if ( - isinstance(control_weight, torch.Tensor) - and control_weight.ndim > 5 - and num_total_frames_with_padding > control_weight.shape[3] - ): - pad_t = num_total_frames_with_padding - control_weight.shape[3] - pad_weight = control_weight[:, :, :, -1:].repeat(1, 1, 1, pad_t, 1, 1) - control_weight = torch.cat([control_weight, pad_weight], dim=3) - else: - num_total_frames_with_padding = T - N_clip = (num_total_frames_with_padding - self.num_input_frames) // num_new_generated_frames - - video = [] - initial_condition_input = None - - prev_frames = None - if input_video is not None: - prev_frames = torch.zeros_like(input_video).cuda() - prev_frames[:, :, : self.num_input_frames] = (input_video[:, :, : self.num_input_frames] + 1) * 255.0 / 2 - log.info(f"N_clip: {N_clip}") - for i_clip in tqdm(range(N_clip)): - log.info(f"input_video shape: {input_video.shape}") - # data_batch_i = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in data_batch.items()} - data_batch_i = {k: v for k, v in data_batch.items()} - start_frame = num_new_generated_frames * i_clip - end_frame = num_new_generated_frames * (i_clip + 1) + self.num_input_frames - - # Prepare x_sigma_max - if input_video is not None: - input_frames = input_video[:, :, start_frame:end_frame].cuda() - x0 = self.model.encode(input_frames).contiguous() - x_sigma_max = self.model.get_x_from_clean(x0, self.sigma_max, seed=(self.seed + i_clip)) - else: - assert False - x_sigma_max = None - - data_batch_i[hint_key] = control_input[:, :, start_frame:end_frame].cuda() - 
latent_hint = [] - log.info("Starting latent encoding") - for b in range(B): - data_batch_p = {k: v for k, v in data_batch_i.items()} - data_batch_p[hint_key] = data_batch_i[hint_key][b : b + 1] - if len(control_inputs_list) >= 1 and len(control_inputs_list[0]) > 1: - latent_hint_i = [] - for idx in range(0, data_batch_p[hint_key].size(1), 3): - x_rgb = data_batch_p[hint_key][:, idx : idx + 3] - latent_hint_i.append(self.model.encode(x_rgb)) - latent_hint.append(torch.cat(latent_hint_i).unsqueeze(0)) - else: - latent_hint.append(self.model.encode_latent(data_batch_p)) - data_batch_i["latent_hint"] = latent_hint = torch.cat(latent_hint) - log.info("Completed latent encoding") - - # Resize control_weight if needed - if isinstance(control_weight, torch.Tensor) and control_weight.ndim > 4: - control_weight_t = control_weight[..., start_frame:end_frame, :, :] - t, h, w = latent_hint.shape[-3:] - data_batch_i["control_weight"] = resize_control_weight_map(control_weight_t, (t, h // 2, w // 2)) - - num_input_frames = self.num_input_frames - prev_frames_patched = split_video_into_patches( - prev_frames, control_input.shape[-2], control_input.shape[-1] - ) - input_frames = prev_frames_patched.bfloat16() / 255.0 * 2 - 1 - condition_latent = self.model.encode(input_frames).contiguous() - - # Generate video frames for this clip (batched) - log.info("Starting diffusion sampling") - latents = generate_world_from_control( - model=self.model, - state_shape=state_shape, - is_negative_prompt=True, - data_batch=data_batch_i, - guidance=self.guidance, - num_steps=self.num_steps, - seed=(self.seed + i_clip), - condition_latent=condition_latent, - num_input_frames=num_input_frames, - sigma_max=self.sigma_max if x_sigma_max is not None else None, - x_sigma_max=x_sigma_max, - ) - log.info("Completed diffusion sampling") - - log.info("Starting VAE decode") - frames = self._run_tokenizer_decoding(latents) # [B, T, H, W, C] or similar - log.info("Completed VAE decode") - - if i_clip == 0: - video.append(frames) - else: - video.append(frames[:, :, self.num_input_frames :]) - - prev_frames = torch.zeros_like(frames) - prev_frames[:, :, : self.num_input_frames] = frames[:, :, -self.num_input_frames :] - - video = torch.cat(video, dim=2)[:, :, :T] - video = video.permute(0, 2, 3, 4, 1).numpy() - return video diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_ctrl.py b/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_ctrl.py deleted file mode 100644 index 1b76aed55a8872921a9b7da20139e8b15fb9fdef..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_ctrl.py +++ /dev/null @@ -1,1003 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
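# A minimal sketch of the sliding-window clip schedule used by the single-view pipelines above
# (the multiview variant trims to whole clips instead of padding), assuming it can be isolated
# into a hypothetical helper named `plan_clips` (the name and signature are illustrative, not
# part of the repository): each clip re-generates the last `num_input_frames` frames of the
# previous clip as conditioning, so only `num_video_frames - num_input_frames` new frames are
# produced per clip, and the control video is padded with duplicate end frames until it fills a
# whole number of clips.
def plan_clips(total_frames: int, num_video_frames: int, num_input_frames: int) -> list[tuple[int, int]]:
    new_per_clip = num_video_frames - num_input_frames
    # Pad duplicate frames at the end so (total - overlap) divides evenly into clips.
    remainder = (total_frames - num_input_frames) % new_per_clip
    padded = total_frames + (0 if remainder == 0 else new_per_clip - remainder)
    n_clip = (padded - num_input_frames) // new_per_clip
    # (start_frame, end_frame) ranges fed to the model for each clip, matching
    # start = num_new_generated_frames * i_clip and
    # end = num_new_generated_frames * (i_clip + 1) + num_input_frames in the loops above.
    return [(new_per_clip * i, new_per_clip * (i + 1) + num_input_frames) for i in range(n_clip)]

# Example: a 121-frame model conditioned on 1 overlap frame over a 241-frame control video
# yields two clips: plan_clips(241, 121, 1) == [(0, 121), (120, 241)].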
- -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union - -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import Tensor - -from cosmos_transfer1.diffusion.conditioner import BaseVideoCondition, VideoConditionerWithCtrl, VideoExtendCondition -from cosmos_transfer1.diffusion.inference.inference_utils import merge_patches_into_video, split_video_into_patches -from cosmos_transfer1.diffusion.model.model_t2w import DiffusionT2WModel, broadcast_condition -from cosmos_transfer1.diffusion.model.model_v2w import DiffusionV2WModel, DistillV2WModel -from cosmos_transfer1.diffusion.module.parallel import broadcast, cat_outputs_cp, split_inputs_cp -from cosmos_transfer1.diffusion.networks.distill_controlnet_wrapper import DistillControlNet -from cosmos_transfer1.utils import log, misc -from cosmos_transfer1.utils.lazy_config import instantiate as lazy_instantiate - -T = TypeVar("T") -IS_PREPROCESSED_KEY = "is_preprocessed" - - -class VideoDiffusionModelWithCtrl(DiffusionV2WModel): - def build_model(self) -> torch.nn.ModuleDict: - log.info("Start creating base model") - base_model = super().build_model() - # initialize base model - self.load_base_model(base_model) - log.info("Done creating base model") - - log.info("Start creating ctrlnet model") - net = lazy_instantiate(self.config.net_ctrl) - conditioner = base_model.conditioner - logvar = base_model.logvar - # initialize controlnet encoder - model = torch.nn.ModuleDict({"net": net, "conditioner": conditioner, "logvar": logvar}) - model.load_state_dict(base_model.state_dict(), strict=False) - - model.base_model = base_model - log.info("Done creating ctrlnet model") - - self.hint_key = self.config.hint_key["hint_key"] - return model - - @property - def base_net(self): - return self.model.base_model.net - - @property - def conditioner(self): - return self.model.conditioner - - def load_base_model(self, base_model) -> None: - config = self.config - if config.base_load_from is not None: - checkpoint_path = config.base_load_from["load_path"] - else: - checkpoint_path = "" - - if checkpoint_path: - log.info(f"Loading base model checkpoint (local): {checkpoint_path}") - state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) - log.success(f"Complete loading base model checkpoint (local): {checkpoint_path}") - - if "ema" in state_dict and state_dict["ema"] is not None: - # Copy the base model weights from ema model. - log.info("Copying ema to base model") - base_state_dict = {k.replace("-", "."): v for k, v in state_dict["ema"].items()} - elif "model" in state_dict: - # Copy the base model weights from reg model. 
- log.warning("Using non-EMA base model") - base_state_dict = state_dict["model"] - else: - log.info("Loading from an EMA only model") - base_state_dict = state_dict - missing, unexpected = base_model.load_state_dict(base_state_dict, strict=False) - log.info(f"Missing: {missing}") - log.info(f"Unexpected: {unexpected}") - log.info("Done loading the base model checkpoint.") - - def get_data_and_condition( - self, data_batch: dict[str, Tensor], **kwargs - ) -> Tuple[Tensor, VideoConditionerWithCtrl]: - # process the control input - hint_key = self.config.hint_key["hint_key"] - _data = {hint_key: data_batch[hint_key]} - if IS_PREPROCESSED_KEY in data_batch: - _data[IS_PREPROCESSED_KEY] = data_batch[IS_PREPROCESSED_KEY] - data_batch[hint_key] = _data[hint_key] - data_batch["hint_key"] = hint_key - raw_state, latent_state, condition = super().get_data_and_condition(data_batch, **kwargs) - use_multicontrol = ( - ("control_weight" in data_batch) - and not isinstance(data_batch["control_weight"], float) - and data_batch["control_weight"].shape[0] > 1 - ) - if use_multicontrol: # encode individual conditions separately - latent_hint = [] - num_conditions = data_batch[data_batch["hint_key"]].size(1) // 3 - for i in range(num_conditions): - cond_mask = [False] * num_conditions - cond_mask[i] = True - latent_hint += [self.encode_latent(data_batch, cond_mask=cond_mask)] - latent_hint = torch.cat(latent_hint) - else: - latent_hint = self.encode_latent(data_batch) - - # add extra conditions - data_batch["latent_hint"] = latent_hint - setattr(condition, hint_key, latent_hint) - setattr(condition, "base_model", self.model.base_model) - return raw_state, latent_state, condition - - def get_x_from_clean( - self, - in_clean_img: torch.Tensor, - sigma_max: float | None, - seed: int = 1, - ) -> Tensor: - """ - in_clean_img (torch.Tensor): input clean image for image-to-image/video-to-video by adding noise then denoising - sigma_max (float): maximum sigma applied to in_clean_image for image-to-image/video-to-video - """ - if in_clean_img is None: - return None - generator = torch.Generator(device=self.tensor_kwargs["device"]) - generator.manual_seed(seed) - noise = torch.randn(*in_clean_img.shape, **self.tensor_kwargs, generator=generator) - if sigma_max is None: - sigma_max = self.sde.sigma_max - x_sigma_max = in_clean_img + noise * sigma_max - return x_sigma_max - - def encode_latent(self, data_batch: dict, cond_mask: list = []) -> torch.Tensor: - x = data_batch[data_batch["hint_key"]] - latent = [] - # control input goes through tokenizer, which always takes 3-input channels - num_conditions = x.size(1) // 3 # input conditions were concatenated along channel dimension - if num_conditions > 1 and self.config.hint_dropout_rate > 0: - if torch.is_grad_enabled(): # during training, randomly dropout some conditions - cond_mask = torch.rand(num_conditions) > self.config.hint_dropout_rate - if not cond_mask.any(): # make sure at least one condition is present - cond_mask = [True] * num_conditions - elif not cond_mask: # during inference, use hint_mask to indicate which conditions are used - cond_mask = self.config.hint_mask - else: - cond_mask = [True] * num_conditions - for idx in range(0, x.size(1), 3): - x_rgb = x[:, idx : idx + 3] - if not cond_mask[idx // 3]: # if the condition is not selected, replace with a black image - x_rgb = torch.zeros_like(x_rgb) - latent.append(self.encode(x_rgb)) - latent = torch.cat(latent, dim=1) - return latent - - def get_x0_fn_from_batch( - self, - data_batch: Dict, - guidance: 
float = 1.5, - is_negative_prompt: bool = False, - condition_latent: torch.Tensor = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - seed: int = 1, - target_h: int = 88, - target_w: int = 160, - patch_h: int = 88, - patch_w: int = 160, - use_batch_processing: bool = True, - ) -> Callable: - """ - Generates a callable function `x0_fn` based on the provided data batch and guidance factor. - - This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. - - Args: - - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` - - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. - - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. - - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" - - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - - target_h (int): final stitched latent height - - target_w (int): final stitched latent width - - patch_h (int): latent patch height for each network inference - - patch_w (int): latent patch width for each network inference - - Returns: - - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin - - The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. - """ - # data_batch should be the one processed by self.get_data_and_condition - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - # Add conditions for long video generation. - - if condition_latent is None: - condition_latent = torch.zeros(data_batch["latent_hint"].shape, **self.tensor_kwargs) - num_condition_t = 0 - condition_video_augment_sigma_in_inference = 1000 - - if use_batch_processing: - condition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, condition, num_condition_t - ) - uncondition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, uncondition, num_condition_t - ) - else: - condition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent[:1], condition, num_condition_t - ) - uncondition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent[:1], uncondition, num_condition_t - ) - condition.video_cond_bool = True - uncondition.video_cond_bool = False # Not do cfg on condition frames - - # Add extra conditions for ctrlnet. 
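# [Editor's sketch] The target_h/target_w and patch_h/patch_w arguments documented above drive the
# patch-grid arithmetic inside x0_fn further below: the number of patches per axis and the overlap
# between neighbouring patches are chosen so that the patches tile the stitched latent exactly.
# A worked example with illustrative numbers (not taken from any real config):
def patch_grid(target: int, patch: int):
    """Return (number of patches, overlap between neighbours) along one axis."""
    n = (target - 1) // patch + 1
    overlap = 0 if n == 1 else (n * patch - target) // (n - 1)
    # Same consistency check as in x0_fn: the overlapping patches must cover the target exactly.
    assert n * patch - overlap * (n - 1) == target
    return n, overlap


print(patch_grid(160, 160))  # (1, 0)  -> a single patch, no stitching needed
print(patch_grid(240, 160))  # (2, 80) -> two patches overlapping by 80 latent columns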
- latent_hint = data_batch["latent_hint"] - hint_key = data_batch["hint_key"] - setattr(condition, hint_key, latent_hint) - if "use_none_hint" in data_batch and data_batch["use_none_hint"]: - setattr(uncondition, hint_key, None) - else: - setattr(uncondition, hint_key, latent_hint) - - # Add extra conditions for ctrlnet. - # Handle regional prompting information - if "regional_contexts" in data_batch and "region_masks" in data_batch: - setattr(condition, "regional_contexts", data_batch["regional_contexts"]) - setattr(condition, "region_masks", data_batch["region_masks"]) - # For unconditioned generation, we still need the region masks but not the regional contexts - setattr(uncondition, "region_masks", data_batch["region_masks"]) - setattr(uncondition, "regional_contexts", None) - - to_cp = self.net.is_context_parallel_enabled - # For inference, check if parallel_state is initialized - if parallel_state.is_initialized(): - condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) - uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) - - cp_group = parallel_state.get_context_parallel_group() - latent_hint = getattr(condition, hint_key) - seq_dim = 3 if latent_hint.ndim == 6 else 2 - latent_hint = split_inputs_cp(latent_hint, seq_dim=seq_dim, cp_group=cp_group) - setattr(condition, hint_key, latent_hint) - if getattr(uncondition, hint_key) is not None: - setattr(uncondition, hint_key, latent_hint) - if hasattr(condition, "regional_contexts") and getattr(condition, "regional_contexts") is not None: - regional_contexts = getattr(condition, "regional_contexts") - regional_contexts = split_inputs_cp(regional_contexts, seq_dim=2, cp_group=cp_group) - setattr(condition, "regional_contexts", regional_contexts) - - if hasattr(condition, "region_masks") and getattr(condition, "region_masks") is not None: - region_masks = getattr(condition, "region_masks") - region_masks = split_inputs_cp(region_masks, seq_dim=2, cp_group=cp_group) - setattr(condition, "region_masks", region_masks) - setattr(uncondition, "region_masks", region_masks) - - setattr(condition, "base_model", self.model.base_model) - setattr(uncondition, "base_model", self.model.base_model) - if hasattr(self, "hint_encoders"): - self.model.net.hint_encoders = self.hint_encoders - - def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor): - B = noise_x.shape[0] # Batch dimension - w, h = target_w, target_h - n_img_w = (w - 1) // patch_w + 1 - n_img_h = (h - 1) // patch_h + 1 - - overlap_size_w = overlap_size_h = 0 - if n_img_w > 1: - overlap_size_w = (n_img_w * patch_w - w) // (n_img_w - 1) - assert n_img_w * patch_w - overlap_size_w * (n_img_w - 1) == w - if n_img_h > 1: - overlap_size_h = (n_img_h * patch_h - h) // (n_img_h - 1) - assert n_img_h * patch_h - overlap_size_h * (n_img_h - 1) == h - - if use_batch_processing: - condition.gt_latent = condition_latent - uncondition.gt_latent = condition_latent - setattr(condition, hint_key, latent_hint) - if getattr(uncondition, hint_key) is not None: - setattr(uncondition, hint_key, latent_hint) - # Batch denoising - cond_x0 = self.denoise( - noise_x, - sigma, - condition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed=seed, - ).x0_pred_replaced - - uncond_x0 = self.denoise( - noise_x, - sigma, - uncondition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed=seed, - ).x0_pred_replaced - x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) - - merged = 
merge_patches_into_video(x0, overlap_size_h, overlap_size_w, n_img_h, n_img_w) - return split_video_into_patches(merged, patch_h, patch_w) - - batch_images = noise_x - batch_sigma = sigma - output = [] - for idx, cur_images in enumerate(batch_images): - noise_x = cur_images.unsqueeze(0) - sigma = batch_sigma[idx : idx + 1] - condition.gt_latent = condition_latent[idx : idx + 1] - uncondition.gt_latent = condition_latent[idx : idx + 1] - setattr(condition, hint_key, latent_hint[idx : idx + 1]) - if getattr(uncondition, hint_key) is not None: - setattr(uncondition, hint_key, latent_hint[idx : idx + 1]) - - cond_x0 = self.denoise( - noise_x, - sigma, - condition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed=seed, - ).x0_pred_replaced - uncond_x0 = self.denoise( - noise_x, - sigma, - uncondition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed=seed, - ).x0_pred_replaced - x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) - output.append(x0) - output = rearrange(torch.stack(output), "(n t) b ... -> (b n t) ...", n=n_img_h, t=n_img_w) - final_output = merge_patches_into_video(output, overlap_size_h, overlap_size_w, n_img_h, n_img_w) - final_output = split_video_into_patches(final_output, patch_h, patch_w) - return final_output - - return x0_fn - - def generate_samples_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - seed: int = 1, - state_shape: Tuple | None = None, - n_sample: int | None = None, - is_negative_prompt: bool = False, - num_steps: int = 35, - condition_latent: Union[torch.Tensor, None] = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - x_sigma_max: Optional[torch.Tensor] = None, - sigma_max: float | None = None, - target_h: int = 88, - target_w: int = 160, - patch_h: int = 88, - patch_w: int = 160, - use_batch_processing: bool = True, - ) -> Tensor: - """ - Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. - Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. - If this feature is stablized, we could consider to move this function to the base model. - - Args: - condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. - num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half - """ - assert patch_h <= target_h and patch_w <= target_w - if n_sample is None: - input_key = self.input_data_key - n_sample = data_batch[input_key].shape[0] - if state_shape is None: - log.debug(f"Default Video state shape is used. 
{self.state_shape}") - state_shape = self.state_shape - x0_fn = self.get_x0_fn_from_batch( - data_batch, - guidance, - is_negative_prompt=is_negative_prompt, - condition_latent=condition_latent, - num_condition_t=num_condition_t, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed=seed, - target_h=target_h, - target_w=target_w, - patch_h=patch_h, - patch_w=patch_w, - use_batch_processing=use_batch_processing, - ) - - if sigma_max is None: - sigma_max = self.sde.sigma_max - - if x_sigma_max is None: - x_sigma_max = ( - misc.arch_invariant_rand( - (n_sample,) + tuple(state_shape), - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - * sigma_max - ) - - if self.net.is_context_parallel_enabled: - x_sigma_max = broadcast(x_sigma_max, to_tp=False, to_cp=True) - x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - - samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max) - - if self.net.is_context_parallel_enabled: - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) - - return samples - - -class VideoDiffusionT2VModelWithCtrl(DiffusionT2WModel): - def build_model(self) -> torch.nn.ModuleDict: - log.info("Start creating base model") - base_model = super().build_model() - # initialize base model - config = self.config - self.load_base_model(base_model) - log.info("Done creating base model") - - log.info("Start creating ctrlnet model") - net = lazy_instantiate(self.config.net_ctrl) - conditioner = base_model.conditioner - logvar = base_model.logvar - # initialize controlnet encoder - model = torch.nn.ModuleDict({"net": net, "conditioner": conditioner, "logvar": logvar}) - model.load_state_dict(base_model.state_dict(), strict=False) - - model.base_model = base_model - log.info("Done creating ctrlnet model") - - self.hint_key = self.config.hint_key["hint_key"] - return model - - @property - def base_net(self): - return self.model.base_model.net - - @property - def conditioner(self): - return self.model.conditioner - - def load_base_model(self, base_model) -> None: - config = self.config - if config.base_load_from is not None: - checkpoint_path = config.base_load_from["load_path"] - else: - checkpoint_path = "" - if checkpoint_path: - log.info(f"Loading base model checkpoint (local): {checkpoint_path}") - state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) - log.success(f"Complete loading base model checkpoint (local): {checkpoint_path}") - - if "ema" in state_dict: - # Copy the base model weights from ema model. - log.info("Copying ema to base model") - base_state_dict = {k.replace("-", "."): v for k, v in state_dict["ema"].items()} - elif "model" in state_dict: - # Copy the base model weights from reg model. 
- log.warning("Using non-EMA base model") - base_state_dict = state_dict["model"] - else: - log.info("Loading from an EMA only model") - base_state_dict = state_dict - base_model.load_state_dict(base_state_dict, strict=False) - log.info("Done loading the base model checkpoint.") - - def get_data_and_condition( - self, data_batch: dict[str, Tensor], **kwargs - ) -> Tuple[Tensor, VideoConditionerWithCtrl]: - # process the control input - hint_key = self.config.hint_key["hint_key"] - _data = {hint_key: data_batch[hint_key]} - if IS_PREPROCESSED_KEY in data_batch: - _data[IS_PREPROCESSED_KEY] = data_batch[IS_PREPROCESSED_KEY] - data_batch[hint_key] = _data[hint_key] - data_batch["hint_key"] = hint_key - raw_state, latent_state, condition = super().get_data_and_condition(data_batch, **kwargs) - use_multicontrol = ( - ("control_weight" in data_batch) - and not isinstance(data_batch["control_weight"], float) - and data_batch["control_weight"].shape[0] > 1 - ) - if use_multicontrol: # encode individual conditions separately - latent_hint = [] - num_conditions = data_batch[data_batch["hint_key"]].size(1) // 3 - for i in range(num_conditions): - cond_mask = [False] * num_conditions - cond_mask[i] = True - latent_hint += [self.encode_latent(data_batch, cond_mask=cond_mask)] - latent_hint = torch.cat(latent_hint) - else: - latent_hint = self.encode_latent(data_batch) - - # add extra conditions - data_batch["latent_hint"] = latent_hint - setattr(condition, hint_key, latent_hint) - setattr(condition, "base_model", self.model.base_model) - return raw_state, latent_state, condition - - def get_x_from_clean( - self, - in_clean_img: torch.Tensor, - sigma_max: float | None, - seed: int = 1, - ) -> Tensor: - """ - in_clean_img (torch.Tensor): input clean image for image-to-image/video-to-video by adding noise then denoising - sigma_max (float): maximum sigma applied to in_clean_image for image-to-image/video-to-video - """ - if in_clean_img is None: - return None - generator = torch.Generator(device=self.tensor_kwargs["device"]) - generator.manual_seed(seed) - noise = torch.randn(*in_clean_img.shape, **self.tensor_kwargs, generator=generator) - if sigma_max is None: - sigma_max = self.sde.sigma_max - x_sigma_max = in_clean_img + noise * sigma_max - return x_sigma_max - - def encode_latent(self, data_batch: dict, cond_mask: list = []) -> torch.Tensor: - x = data_batch[data_batch["hint_key"]] - latent = [] - # control input goes through tokenizer, which always takes 3-input channels - num_conditions = x.size(1) // 3 # input conditions were concatenated along channel dimension - if num_conditions > 1 and self.config.hint_dropout_rate > 0: - if torch.is_grad_enabled(): # during training, randomly dropout some conditions - cond_mask = torch.rand(num_conditions) > self.config.hint_dropout_rate - if not cond_mask.any(): # make sure at least one condition is present - cond_mask = [True] * num_conditions - elif not cond_mask: # during inference, use hint_mask to indicate which conditions are used - cond_mask = self.config.hint_mask - else: - cond_mask = [True] * num_conditions - for idx in range(0, x.size(1), 3): - x_rgb = x[:, idx : idx + 3] - if not cond_mask[idx // 3]: # if the condition is not selected, replace with a black image - x_rgb = torch.zeros_like(x_rgb) - latent.append(self.encode(x_rgb)) - latent = torch.cat(latent, dim=1) - return latent - - def get_x0_fn_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - is_negative_prompt: bool = False, - condition_latent: torch.Tensor = None, - 
num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - ) -> Callable: - """ - Generates a callable function `x0_fn` based on the provided data batch and guidance factor. - - This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. - - Args: - - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` - - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. - - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. - - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" - - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - - Returns: - - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin - - The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. - """ - # data_batch should be the one processed by self.get_data_and_condition - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - # Add extra conditions for ctrlnet. 
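# [Editor's sketch] x0_fn defined further below blends the conditioned and unconditioned x0
# predictions with classifier-free guidance: it starts from the conditioned prediction and pushes it
# further away from the unconditioned one. A tiny numeric check of that combination (placeholder
# tensors, not real model outputs):
import torch


def apply_guidance(cond_x0: torch.Tensor, uncond_x0: torch.Tensor, guidance: float) -> torch.Tensor:
    # guidance = 0 returns the conditioned prediction unchanged; larger values extrapolate past it.
    return cond_x0 + guidance * (cond_x0 - uncond_x0)


print(apply_guidance(torch.tensor([1.0]), torch.tensor([0.5]), guidance=1.5))  # tensor([1.7500])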
- latent_hint = data_batch["latent_hint"] - hint_key = data_batch["hint_key"] - setattr(condition, hint_key, latent_hint) - if "use_none_hint" in data_batch and data_batch["use_none_hint"]: - setattr(uncondition, hint_key, None) - else: - setattr(uncondition, hint_key, latent_hint) - - # Handle regional prompting information - if "regional_contexts" in data_batch and "region_masks" in data_batch: - setattr(condition, "regional_contexts", data_batch["regional_contexts"]) - setattr(condition, "region_masks", data_batch["region_masks"]) - # For unconditioned generation, we still need the region masks but not the regional contexts - setattr(uncondition, "region_masks", data_batch["region_masks"]) - setattr(uncondition, "regional_contexts", None) - - to_cp = self.net.is_context_parallel_enabled - # For inference, check if parallel_state is initialized - if parallel_state.is_initialized(): - condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) - uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) - - cp_group = parallel_state.get_context_parallel_group() - latent_hint = getattr(condition, hint_key) - seq_dim = 3 if latent_hint.ndim == 6 else 2 - latent_hint = split_inputs_cp(latent_hint, seq_dim=seq_dim, cp_group=cp_group) - setattr(condition, hint_key, latent_hint) - if getattr(uncondition, hint_key) is not None: - setattr(uncondition, hint_key, latent_hint) - - if hasattr(condition, "regional_contexts") and getattr(condition, "regional_contexts") is not None: - regional_contexts = getattr(condition, "regional_contexts") - regional_contexts = split_inputs_cp(regional_contexts, seq_dim=2, cp_group=cp_group) - setattr(condition, "regional_contexts", regional_contexts) - - if hasattr(condition, "region_masks") and getattr(condition, "region_masks") is not None: - region_masks = getattr(condition, "region_masks") - region_masks = split_inputs_cp(region_masks, seq_dim=2, cp_group=cp_group) - setattr(condition, "region_masks", region_masks) - setattr(uncondition, "region_masks", region_masks) - setattr(condition, "base_model", self.model.base_model) - setattr(uncondition, "base_model", self.model.base_model) - if hasattr(self, "hint_encoders"): - self.model.net.hint_encoders = self.hint_encoders - - def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - cond_x0 = self.denoise( - noise_x, - sigma, - condition, - ).x0 - uncond_x0 = self.denoise( - noise_x, - sigma, - uncondition, - ).x0 - return cond_x0 + guidance * (cond_x0 - uncond_x0) - - return x0_fn - - def generate_samples_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - seed: int = 1, - state_shape: Tuple | None = None, - n_sample: int | None = None, - is_negative_prompt: bool = False, - num_steps: int = 35, - condition_latent: Union[torch.Tensor, None] = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - x_sigma_max: Optional[torch.Tensor] = None, - sigma_max: float | None = None, - **kwargs, - ) -> Tensor: - """ - Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. - Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. - If this feature is stablized, we could consider to move this function to the base model. - - Args: - condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. 
- num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half - """ - if n_sample is None: - input_key = self.input_data_key - n_sample = data_batch[input_key].shape[0] - if state_shape is None: - log.debug(f"Default Video state shape is used. {self.state_shape}") - state_shape = self.state_shape - - x0_fn = self.get_x0_fn_from_batch( - data_batch, - guidance, - is_negative_prompt=is_negative_prompt, - condition_latent=condition_latent, - num_condition_t=num_condition_t, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - ) - - if sigma_max is None: - sigma_max = self.sde.sigma_max - - if x_sigma_max is None: - x_sigma_max = ( - misc.arch_invariant_rand( - (n_sample,) + tuple(state_shape), - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - * sigma_max - ) - - if self.net.is_context_parallel_enabled: - x_sigma_max = broadcast(x_sigma_max, to_tp=False, to_cp=True) - x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - - samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max) - - if self.net.is_context_parallel_enabled: - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) - return samples - - -class VideoDistillModelWithCtrl(DistillV2WModel): - def build_model(self) -> torch.nn.ModuleDict: - log.info("Start creating base model") - base_model = super().build_model() - # initialize base model - log.info("Done creating base model") - - log.info("Start creating ctrlnet model") - net = DistillControlNet(self.config) - net.base_model.net.load_state_dict(base_model["net"].state_dict()) - conditioner = base_model.conditioner - logvar = base_model.logvar - # initialize controlnet encoder - model = torch.nn.ModuleDict({"net": net, "conditioner": conditioner, "logvar": logvar}) - model.base_model = net.base_model.net - - self.hint_key = self.config.hint_key["hint_key"] - return model - - @property - def base_net(self): - return self.model.base_model.net - - @property - def conditioner(self): - return self.model.conditioner - - def get_data_and_condition( - self, data_batch: dict[str, Tensor], **kwargs - ) -> Tuple[Tensor, VideoConditionerWithCtrl]: - # process the control input - hint_key = self.config.hint_key["hint_key"] - _data = {hint_key: data_batch[hint_key]} - if IS_PREPROCESSED_KEY in data_batch: - _data[IS_PREPROCESSED_KEY] = data_batch[IS_PREPROCESSED_KEY] - data_batch[hint_key] = _data[hint_key] - data_batch["hint_key"] = hint_key - raw_state, latent_state, condition = super().get_data_and_condition(data_batch, **kwargs) - use_multicontrol = ( - ("control_weight" in data_batch) - and not isinstance(data_batch["control_weight"], float) - and data_batch["control_weight"].shape[0] > 1 - ) - if use_multicontrol: # encode individual conditions separately - latent_hint = [] - num_conditions = data_batch[data_batch["hint_key"]].size(1) // 3 - for i in range(num_conditions): - cond_mask = [False] * num_conditions - cond_mask[i] = True - latent_hint += [self.encode_latent(data_batch, cond_mask=cond_mask)] - latent_hint = torch.cat(latent_hint) - else: - latent_hint = self.encode_latent(data_batch) - - # add extra conditions - data_batch["latent_hint"] = latent_hint - setattr(condition, hint_key, latent_hint) - setattr(condition, "base_model", self.model.base_model) - return raw_state, latent_state, condition - - def get_x_from_clean( - self, - in_clean_img: torch.Tensor, - sigma_max: float | None, - seed: int = 1, - ) -> 
Tensor: - """ - in_clean_img (torch.Tensor): input clean image for image-to-image/video-to-video by adding noise then denoising - sigma_max (float): maximum sigma applied to in_clean_image for image-to-image/video-to-video - """ - if in_clean_img is None: - return None - generator = torch.Generator(device=self.tensor_kwargs["device"]) - generator.manual_seed(seed) - noise = torch.randn(*in_clean_img.shape, **self.tensor_kwargs, generator=generator) - if sigma_max is None: - sigma_max = self.sde.sigma_max - x_sigma_max = in_clean_img + noise * sigma_max - return x_sigma_max - - def encode_latent(self, data_batch: dict, cond_mask: list = []) -> torch.Tensor: - x = data_batch[data_batch["hint_key"]] - latent = [] - # control input goes through tokenizer, which always takes 3-input channels - num_conditions = x.size(1) // 3 # input conditions were concatenated along channel dimension - if num_conditions > 1 and self.config.hint_dropout_rate > 0: - if torch.is_grad_enabled(): # during training, randomly dropout some conditions - cond_mask = torch.rand(num_conditions) > self.config.hint_dropout_rate - if not cond_mask.any(): # make sure at least one condition is present - cond_mask = [True] * num_conditions - elif not cond_mask: # during inference, use hint_mask to indicate which conditions are used - cond_mask = self.config.hint_mask - else: - cond_mask = [True] * num_conditions - for idx in range(0, x.size(1), 3): - x_rgb = x[:, idx : idx + 3] - if not cond_mask[idx // 3]: # if the condition is not selected, replace with a black image - x_rgb = torch.zeros_like(x_rgb) - latent.append(self.encode(x_rgb)) - latent = torch.cat(latent, dim=1) - return latent - - def generate_samples_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - seed: int = 1, - state_shape: Tuple | None = None, - n_sample: int | None = None, - is_negative_prompt: bool = False, - num_steps: int = 1, # Ignored for distilled models - condition_latent: Union[torch.Tensor, None] = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - x_sigma_max: Optional[torch.Tensor] = None, - sigma_max: float | None = None, - target_h: int = 88, - target_w: int = 160, - patch_h: int = 88, - patch_w: int = 160, - **kwargs, - ) -> torch.Tensor: - """Single-step generation matching internal distilled model""" - - # Same preprocessing as base class - self._normalize_video_databatch_inplace(data_batch) - self._augment_image_dim_inplace(data_batch) - - if n_sample is None: - # input_key = self.input_image_key if is_image_batch else self.input_data_key - input_key = self.input_data_key - n_sample = data_batch[input_key].shape[0] - if state_shape is None: - log.debug(f"Default Video state shape is used. 
{self.state_shape}") - state_shape = self.state_shape - if sigma_max is None: - sigma_max = self.sde.sigma_max - if x_sigma_max is None: - x_sigma_max = ( - misc.arch_invariant_rand( - (n_sample,) + tuple(state_shape), - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - * sigma_max - ) - - # Generate initial noise - batch_shape = (n_sample, *state_shape) - generator = torch.Generator(device=self.tensor_kwargs["device"]) - generator.manual_seed(seed) - random_noise = torch.randn(*batch_shape, generator=generator, **self.tensor_kwargs) - - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - # Handle conditioning - if condition_latent is None: - condition_latent = torch.zeros(data_batch["latent_hint"].shape, **self.tensor_kwargs) - num_condition_t = 0 - condition_video_augment_sigma_in_inference = 1000 - - condition.video_cond_bool = True - condition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, condition, num_condition_t - ) - uncondition.video_cond_bool = True # Not do cfg on condition frames - uncondition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, uncondition, num_condition_t - ) - uncondition.condition_video_indicator = condition.condition_video_indicator.clone() - uncondition.condition_video_input_mask = condition.condition_video_input_mask.clone() - - latent_hint = data_batch["latent_hint"] - hint_key = data_batch["hint_key"] - setattr(condition, hint_key, latent_hint) - if "use_none_hint" in data_batch and data_batch["use_none_hint"]: - setattr(uncondition, hint_key, None) - else: - setattr(uncondition, hint_key, latent_hint) - - to_cp = self.net.is_context_parallel_enabled - # For inference, check if parallel_state is initialized - if parallel_state.is_initialized(): - condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) - uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) - - cp_group = parallel_state.get_context_parallel_group() - latent_hint = getattr(condition, hint_key) - seq_dim = 3 if latent_hint.ndim == 6 else 2 - latent_hint = split_inputs_cp(latent_hint, seq_dim=seq_dim, cp_group=cp_group) - setattr(condition, hint_key, latent_hint) - if getattr(uncondition, hint_key) is not None: - setattr(uncondition, hint_key, latent_hint) - - # not sure if this is consistent w the new distilled model? 
- setattr(condition, "base_model", self.model.base_model) - setattr(uncondition, "base_model", self.model.base_model) - if hasattr(self, "hint_encoders"): - self.model.net.hint_encoders = self.hint_encoders - - cp_enabled = self.net.is_context_parallel_enabled - if cp_enabled: - random_noise = split_inputs_cp(x=random_noise, seq_dim=2, cp_group=self.net.cp_group) - - condition.gt_latent = condition_latent - uncondition.gt_latent = condition_latent - - if self.net.is_context_parallel_enabled: - x_sigma_max = broadcast(x_sigma_max, to_tp=False, to_cp=True) - x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - - samples = self._forward_distilled( - epsilon=random_noise, - condition=condition, - uncondition=uncondition, - guidance=guidance, - hint_key=hint_key, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - target_h=target_h, - target_w=target_w, - patch_h=patch_h, - patch_w=patch_w, - seed=seed, - inference_mode=True, - **kwargs, - ) - cp_enabled = self.net.is_context_parallel_enabled - if cp_enabled: - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) - - return samples - - def _forward_distilled( - self, - epsilon: torch.Tensor, - condition: Any, - uncondition: Any, - guidance: float, - hint_key: str, - condition_video_augment_sigma_in_inference: float = 0.001, - target_h: int = 88, - target_w: int = 160, - patch_h: int = 88, - patch_w: int = 160, - seed: int = 1, - inference_mode: bool = True, - **kwargs, - ) -> torch.Tensor: - """Single forward pass for distilled models""" - B = epsilon.shape[0] # Batch dimension - - w, h = target_w, target_h - n_img_w = (w - 1) // patch_w + 1 - n_img_h = (h - 1) // patch_h + 1 - - overlap_size_w = overlap_size_h = 0 - if n_img_w > 1: - overlap_size_w = (n_img_w * patch_w - w) // (n_img_w - 1) - assert n_img_w * patch_w - overlap_size_w * (n_img_w - 1) == w - if n_img_h > 1: - overlap_size_h = (n_img_h * patch_h - h) // (n_img_h - 1) - assert n_img_h * patch_h - overlap_size_h * (n_img_h - 1) == h - - # Single denoising step at sigma_max - sigma_max = torch.tensor(self.sde.sigma_max).repeat(epsilon.size(0)).to(epsilon.device) - - # Direct network forward pass - no iterative sampling - with torch.no_grad(): - cond_x0 = self.denoise( - noise_x=epsilon * self.sde.sigma_max, # Scale noise to sigma_max - sigma=sigma_max, - condition=condition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed=seed, - ).x0_pred_replaced - return cond_x0 diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_multi_camera_ctrl.py b/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_multi_camera_ctrl.py deleted file mode 100644 index 30379014eb30bcd9eb15919b689f4f0fc2a9dd2e..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_multi_camera_ctrl.py +++ /dev/null @@ -1,300 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, Dict, Optional, Tuple, TypeVar, Union - -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import Tensor - -from cosmos_transfer1.diffusion.conditioner import DataType, VideoConditionerWithCtrl -from cosmos_transfer1.diffusion.model.model_t2w import DiffusionT2WModel, broadcast_condition -from cosmos_transfer1.diffusion.model.model_v2w import DiffusionV2WModel -from cosmos_transfer1.diffusion.model.model_v2w_multiview import DiffusionV2WMultiviewModel -from cosmos_transfer1.diffusion.module.parallel import broadcast, cat_outputs_cp, split_inputs_cp -from cosmos_transfer1.utils import log, misc -from cosmos_transfer1.utils.lazy_config import instantiate as lazy_instantiate - -T = TypeVar("T") -IS_PREPROCESSED_KEY = "is_preprocessed" - - -class MultiVideoDiffusionModelWithCtrl(DiffusionV2WMultiviewModel): - def build_model(self) -> torch.nn.ModuleDict: - log.info("Start creating base model") - base_model = super().build_model() - # initialize base model - config = self.config - self.load_base_model(base_model) - log.info("Done creating base model") - - log.info("Start creating ctrlnet model") - net = lazy_instantiate(self.config.net_ctrl) - conditioner = base_model.conditioner - logvar = base_model.logvar - # initialize controlnet encoder - model = torch.nn.ModuleDict({"net": net, "conditioner": conditioner, "logvar": logvar}) - - model.load_state_dict(base_model.state_dict(), strict=False) - - model.base_model = base_model - log.info("Done creating ctrlnet model") - - self.hint_key = self.config.hint_key["hint_key"] - return model - - @property - def base_net(self): - return self.model.base_model.net - - @property - def conditioner(self): - return self.model.conditioner - - def load_base_model(self, base_model) -> None: - config = self.config - if config.base_load_from is not None: - checkpoint_path = config.base_load_from["load_path"] - else: - checkpoint_path = "" - - if checkpoint_path: - log.info(f"Loading base model checkpoint (local): {checkpoint_path}") - state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) - log.success(f"Complete loading base model checkpoint (local): {checkpoint_path}") - - if "ema" in state_dict: - # Copy the base model weights from ema model. - log.info("Copying ema to base model") - base_state_dict = {k.replace("-", "."): v for k, v in state_dict["ema"].items()} - elif "model" in state_dict: - # Copy the base model weights from reg model. 
- log.warning("Using non-EMA base model") - base_state_dict = state_dict["model"] - else: - log.info("Loading from an EMA only model") - base_state_dict = state_dict - base_model.load_state_dict(base_state_dict, strict=False) - log.info("Done loading the base model checkpoint.") - - def encode_latent(self, data_batch: dict, cond_mask: list = []) -> torch.Tensor: - x = data_batch[data_batch["hint_key"]] - latent = [] - # control input goes through tokenizer, which always takes 3-input channels - num_conditions = x.size(1) // 3 # input conditions were concatenated along channel dimension - if num_conditions > 1 and self.config.hint_dropout_rate > 0: - if not cond_mask: # during inference, use hint_mask to indicate which conditions are used - cond_mask = self.config.hint_mask - else: - cond_mask = [True] * num_conditions - for idx in range(0, x.size(1), 3): - x_rgb = x[:, idx : idx + 3] # B C (V T) H W - if not cond_mask[idx // 3]: # if the condition is not selected, replace with a black image - x_rgb = torch.zeros_like(x_rgb) - latent.append(self.encode(x_rgb)) - latent = torch.cat(latent, dim=1) - return latent - - def get_x0_fn_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - is_negative_prompt: bool = False, - condition_latent: torch.Tensor = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - seed: int = 1, - ) -> Callable: - """ - Generates a callable function `x0_fn` based on the provided data batch and guidance factor. - - This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. - - Args: - - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` - - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. - - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. - - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" - - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - - Returns: - - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin - - The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. 
- """ - # data_batch should be the one processed by self.get_data_and_condition - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - - if "view_indices" in data_batch: - if hasattr(self, "vae"): - comp_factor = getattr(self.vae, "temporal_compression_factor", 8) - elif hasattr(self, "tokenizer"): - comp_factor = getattr(self.tokenizer, "temporal_compression_factor", 8) - else: - comp_factor = 8 - # n_frames = data_batch['num_frames'] - view_indices = rearrange(data_batch["view_indices"], "B (V T) -> B V T", V=self.n_views) - view_indices_B_V_0 = view_indices[:, :, :1] - view_indices_B_V_1T = view_indices[:, :, 1:-1:comp_factor] - view_indices_B_V_T = torch.cat([view_indices_B_V_0, view_indices_B_V_1T], dim=-1) - condition.view_indices_B_T = rearrange(view_indices_B_V_T, "B V T -> B (V T)", V=self.n_views) - condition.data_n_views = self.n_views - uncondition.view_indices_B_T = condition.view_indices_B_T - uncondition.data_n_views = self.n_views - - if condition_latent is None: - batch_size = data_batch["latent_hint"].shape[0] - condition_latent = torch.zeros(batch_size, *self.state_shape, **self.tensor_kwargs) - num_condition_t = 0 - condition_video_augment_sigma_in_inference = 1000 - - condition.video_cond_bool = True - condition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, condition, num_condition_t - ) - - uncondition.video_cond_bool = True # Not do cfg on condition frames - uncondition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, uncondition, num_condition_t - ) - - # Add extra conditions for ctrlnet. - latent_hint = data_batch["latent_hint"] - hint_key = data_batch["hint_key"] - setattr(condition, hint_key, latent_hint) - if "use_none_hint" in data_batch and data_batch["use_none_hint"]: - setattr(uncondition, hint_key, None) - else: - setattr(uncondition, hint_key, latent_hint) - - to_cp = self.net.is_context_parallel_enabled - # For inference, check if parallel_state is initialized - if parallel_state.is_initialized(): # and not self.is_image_batch(data_batch): - condition = broadcast_condition(condition, to_tp=False, to_cp=to_cp) - uncondition = broadcast_condition(uncondition, to_tp=False, to_cp=to_cp) - - cp_group = parallel_state.get_context_parallel_group() - latent_hint = getattr(condition, hint_key) - seq_dim = 3 if latent_hint.ndim == 6 else 2 - latent_hint = split_inputs_cp(latent_hint, seq_dim=seq_dim, cp_group=cp_group) - setattr(condition, hint_key, latent_hint) - if getattr(uncondition, hint_key) is not None: - setattr(uncondition, hint_key, latent_hint) - - setattr(condition, "base_model", self.model.base_model) - setattr(uncondition, "base_model", self.model.base_model) - if hasattr(self, "hint_encoders"): - self.model.net.hint_encoders = self.hint_encoders - - def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - cond_x0 = self.denoise( - noise_x, - sigma, - condition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed=seed, - ).x0_pred_replaced - uncond_x0 = self.denoise( - noise_x, - sigma, - uncondition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed=seed, - ).x0_pred_replaced - return cond_x0 + guidance * (cond_x0 - uncond_x0) - - return x0_fn - - def generate_samples_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - 
seed: int = 1, - state_shape: Tuple | None = None, - n_sample: int | None = None, - is_negative_prompt: bool = False, - num_steps: int = 35, - condition_latent: Union[torch.Tensor, None] = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - x_sigma_max: Optional[torch.Tensor] = None, - sigma_max: float | None = None, - **kwargs, - ) -> Tensor: - """ - Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. - Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. - If this feature is stablized, we could consider to move this function to the base model. - - Args: - condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. - num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half - """ - - is_image_batch = False # self.is_image_batch(data_batch) - - if n_sample is None: - input_key = self.input_image_key if is_image_batch else self.input_data_key - n_sample = data_batch[input_key].shape[0] - if state_shape is None: - if is_image_batch: - state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W - else: - log.debug(f"Default Video state shape is used. {self.state_shape}") - state_shape = self.state_shape - - x0_fn = self.get_x0_fn_from_batch( - data_batch, - guidance, - is_negative_prompt=is_negative_prompt, - condition_latent=condition_latent, - num_condition_t=num_condition_t, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed=seed, - ) - - if sigma_max is None: - sigma_max = self.sde.sigma_max - - if x_sigma_max is None: - x_sigma_max = ( - misc.arch_invariant_rand( - (n_sample,) + tuple(state_shape), - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - * sigma_max - ) - - if self.net.is_context_parallel_enabled: - x_sigma_max = broadcast(x_sigma_max, to_tp=False, to_cp=True) - x_sigma_max = rearrange(x_sigma_max, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - x_sigma_max = rearrange(x_sigma_max, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - - samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max) - - if self.net.is_context_parallel_enabled: - samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views).contiguous() - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) - samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - - return samples diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_t2w.py b/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_t2w.py deleted file mode 100644 index 0b68460ca120ffab429c7b68460352580fd2ce4c..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_t2w.py +++ /dev/null @@ -1,379 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Dict, List, Set, Tuple - -import torch -from torch import Tensor - -from cosmos_transfer1.diffusion.conditioner import BaseVideoCondition, CosmosCondition -from cosmos_transfer1.diffusion.diffusion.functional.batch_ops import batch_mul -from cosmos_transfer1.diffusion.diffusion.modules.denoiser_scaling import EDMScaling -from cosmos_transfer1.diffusion.diffusion.modules.res_sampler import Sampler -from cosmos_transfer1.diffusion.diffusion.types import DenoisePrediction -from cosmos_transfer1.diffusion.module import parallel -from cosmos_transfer1.diffusion.module.blocks import FourierFeatures -from cosmos_transfer1.diffusion.module.pretrained_vae import BaseVAE -from cosmos_transfer1.diffusion.networks.general_dit import GeneralDIT -from cosmos_transfer1.utils import log, misc -from cosmos_transfer1.utils.lazy_config import instantiate as lazy_instantiate - -# from cosmos_transfer1.diffusion.training.models.model import _broadcast -IS_PREPROCESSED_KEY = "is_preprocessed" -from enum import Enum - - -class DataType(Enum): - IMAGE = "image" - VIDEO = "video" - MIX = "mix" - - -class EDMSDE: - def __init__( - self, - sigma_max: float, - sigma_min: float, - ): - self.sigma_max = sigma_max - self.sigma_min = sigma_min - - -class DiffusionT2WModel(torch.nn.Module): - """Text-to-world diffusion model that generates video frames from text descriptions. - - This model implements a diffusion-based approach for generating videos conditioned on text input. - It handles the full pipeline including encoding/decoding through a VAE, diffusion sampling, - and classifier-free guidance. - """ - - def __init__(self, config): - """Initialize the diffusion model. - - Args: - config: Configuration object containing model parameters and architecture settings - """ - super().__init__() - # Initialize trained_data_record with defaultdict, key: image, video, iteration - self.config = config - - self.precision = { - "float32": torch.float32, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - }[config.precision] - self.tensor_kwargs = {"device": "cuda", "dtype": self.precision} - log.debug(f"DiffusionModel: precision {self.precision}") - # Timer passed to network to detect slow ranks. - # 1. set data keys and data information - self.sigma_data = config.sigma_data - self.state_shape = list(config.latent_shape) - self.setup_data_key() - - # 2. 
setup up diffusion processing and scaling~(pre-condition), sampler - self.sde = EDMSDE(sigma_max=80, sigma_min=0.0002) - self.sampler = Sampler() - self.scaling = EDMScaling(self.sigma_data) - self.tokenizer = None - self.model = None - - @property - def net(self): - return self.model.net - - @property - def conditioner(self): - return self.model.conditioner - - @property - def logvar(self): - return self.model.logvar - - def set_up_tokenizer(self, tokenizer_dir: str): - self.tokenizer: BaseVAE = lazy_instantiate(self.config.tokenizer) - self.tokenizer.load_weights(tokenizer_dir) - if hasattr(self.tokenizer, "reset_dtype"): - self.tokenizer.reset_dtype() - - @misc.timer("DiffusionModel: set_up_model") - def set_up_model(self, memory_format: torch.memory_format = torch.preserve_format): - """Initialize the core model components including network, conditioner and logvar.""" - self.model = self.build_model() - self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs) - - def build_model(self) -> torch.nn.ModuleDict: - """Construct the model's neural network components. - - Returns: - ModuleDict containing the network, conditioner and logvar components - """ - config = self.config - net = lazy_instantiate(config.net) - conditioner = lazy_instantiate(config.conditioner) - logvar = torch.nn.Sequential( - FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False) - ) - - return torch.nn.ModuleDict( - { - "net": net, - "conditioner": conditioner, - "logvar": logvar, - } - ) - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - """Encode input state into latent representation using VAE. - - Args: - state: Input tensor to encode - - Returns: - Encoded latent representation scaled by sigma_data - """ - return self.tokenizer.encode(state) * self.sigma_data - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - """Decode latent representation back to pixel space using VAE. - - Args: - latent: Latent tensor to decode - - Returns: - Decoded tensor in pixel space - """ - return self.tokenizer.decode(latent / self.sigma_data) - - def setup_data_key(self) -> None: - """Configure input data keys for video and image data.""" - self.input_data_key = self.config.input_data_key # by default it is video key for Video diffusion model - self.input_image_key = self.config.input_image_key - - def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: CosmosCondition) -> DenoisePrediction: - """ - Performs denoising on the input noise data, noise level, and condition - - Args: - xt (torch.Tensor): The input noise data. - sigma (torch.Tensor): The noise level. - condition (CosmosCondition): conditional information, generated from self.conditioner - - Returns: - DenoisePrediction: The denoised prediction, it includes clean data predicton (x0), \ - noise prediction (eps_pred) and optional confidence (logvar). - """ - - xt = xt.to(**self.tensor_kwargs) - sigma = sigma.to(**self.tensor_kwargs) - # get precondition for the network - c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma) - - # forward pass through the network - net_output = self.net( - x=batch_mul(c_in, xt), # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf - timesteps=c_noise, # Eq. 
7 of https://arxiv.org/pdf/2206.00364.pdf - **condition.to_dict(), - ) - - logvar = self.model.logvar(c_noise) - x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output) - - # get noise prediction based on sde - eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma) - - return DenoisePrediction(x0_pred, eps_pred, logvar) - - -def robust_broadcast(tensor: torch.Tensor, src: int, pg, is_check_shape: bool = False) -> torch.Tensor: - """ - Perform a robust broadcast operation that works regardless of tensor shapes on different ranks. - - Args: - tensor (torch.Tensor): The tensor to broadcast (on src rank) or receive (on other ranks). - src (int): The source rank for the broadcast. Defaults to 0. - - Returns: - torch.Tensor: The broadcasted tensor on all ranks. - """ - # First, broadcast the shape of the tensor - if distributed.get_rank() == src: - shape = torch.tensor(tensor.shape).cuda() - else: - shape = torch.empty(tensor.dim(), dtype=torch.long).cuda() - if is_check_shape: - _verify_param_shape_across_processes(pg, [shape]) - torch.distributed.broadcast(shape, src, group=pg) - - # Resize the tensor on non-src ranks if necessary - if distributed.get_rank() != src: - tensor = tensor.new_empty(shape.tolist()).type_as(tensor) - - # Now broadcast the tensor data - torch.distributed.broadcast(tensor, src, group=pg) - - return tensor - - -def _broadcast(item: torch.Tensor | str | None, to_tp: bool = True, to_cp: bool = True) -> torch.Tensor | str | None: - """ - Broadcast the item from the minimum rank in the specified group(s). - Since global rank = tp_rank + cp_rank * tp_size + ... - First broadcast in the tp_group and then in the cp_group will - ensure that the item is broadcasted across ranks in cp_group and tp_group. - - Parameters: - - item: The item to broadcast (can be a torch.Tensor, str, or None). - - to_tp: Whether to broadcast to the tensor model parallel group. - - to_cp: Whether to broadcast to the context parallel group. - """ - if not parallel_state.is_initialized(): - return item - tp_group = parallel_state.get_tensor_model_parallel_group() - cp_group = parallel_state.get_context_parallel_group() - - to_tp = to_tp and parallel_state.get_tensor_model_parallel_world_size() > 1 - to_cp = to_cp and parallel_state.get_context_parallel_world_size() > 1 - - if to_tp: - min_tp_rank = min(get_process_group_ranks(tp_group)) - - if to_cp: - min_cp_rank = min(get_process_group_ranks(cp_group)) - - if isinstance(item, torch.Tensor): # assume the device is cuda - if to_tp: - item = robust_broadcast(item, min_tp_rank, tp_group) - if to_cp: - item = robust_broadcast(item, min_cp_rank, cp_group) - elif item is not None: - broadcastable_list = [item] - if to_tp: - broadcast_object_list(broadcastable_list, min_tp_rank, group=tp_group) - if to_cp: - broadcast_object_list(broadcastable_list, min_cp_rank, group=cp_group) - - item = broadcastable_list[0] - return item - - -class DistillT2WModel(DiffusionT2WModel): - """Base Video Distillation Model.""" - - def __init__(self, config): - super().__init__(config) - - def get_data_and_condition(self, data_batch: dict[str, Tensor]) -> Tuple[Tensor, CosmosCondition]: - self._normalize_video_databatch_inplace(data_batch) - self._augment_image_dim_inplace(data_batch) - input_key = self.input_data_key # by default it is video key - is_image_batch = self.is_image_batch(data_batch) - is_video_batch = not is_image_batch - - # Broadcast data and condition across TP and CP groups. - # sort keys to make sure the order is same, IMPORTANT! otherwise, nccl will hang! 
- local_keys = sorted(list(data_batch.keys())) - for key in local_keys: - data_batch[key] = _broadcast(data_batch[key], to_tp=True, to_cp=is_video_batch) - - if is_image_batch: - input_key = self.input_image_key - - # Latent state - raw_state = data_batch[input_key] - latent_state = self.encode(raw_state).contiguous() - - # Condition - condition = self.conditioner(data_batch) - if is_image_batch: - condition.data_type = DataType.IMAGE - else: - condition.data_type = DataType.VIDEO - - # VAE has randomness. CP/TP group should have the same encoded output. - - latent_state = _broadcast(latent_state, to_tp=True, to_cp=is_video_batch) - condition = broadcast_condition(condition, to_tp=True, to_cp=is_video_batch) - - def is_image_batch(self, data_batch: dict[str, Tensor]) -> bool: - """We hanlde two types of data_batch. One comes from a joint_dataloader where "dataset_name" can be used to differenciate image_batch and video_batch. - Another comes from a dataloader which we by default assumes as video_data for video model training. - """ - is_image = self.input_image_key in data_batch - is_video = self.input_data_key in data_batch - assert ( - is_image != is_video - ), "Only one of the input_image_key or input_data_key should be present in the data_batch." - return is_image - - def _normalize_video_databatch_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: - """ - Normalizes video data in-place on a CUDA device to reduce data loading overhead. - - This function modifies the video data tensor within the provided data_batch dictionary - in-place, scaling the uint8 data from the range [0, 255] to the normalized range [-1, 1]. - - Warning: - A warning is issued if the data has not been previously normalized. - - Args: - data_batch (dict[str, Tensor]): A dictionary containing the video data under a specific key. - This tensor is expected to be on a CUDA device and have dtype of torch.uint8. - - Side Effects: - Modifies the 'input_data_key' tensor within the 'data_batch' dictionary in-place. - - Note: - This operation is performed directly on the CUDA device to avoid the overhead associated - with moving data to/from the GPU. Ensure that the tensor is already on the appropriate device - and has the correct dtype (torch.uint8) to avoid unexpected behaviors. - """ - input_key = self.input_data_key if input_key is None else input_key - # only handle video batch - if input_key in data_batch: - # Check if the data has already been normalized and avoid re-normalizing - if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: - assert torch.is_floating_point(data_batch[input_key]), "Video data is not in float format." - assert torch.all( - (data_batch[input_key] >= -1.0001) & (data_batch[input_key] <= 1.0001) - ), f"Video data is not in the range [-1, 1]. get data range [{data_batch[input_key].min()}, {data_batch[input_key].max()}]" - else: - assert data_batch[input_key].dtype == torch.uint8, "Video data is not in uint8 format." 
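# (Editor's note, illustrative, not part of the diff) The statement below rescales raw
# uint8 video from [0, 255] to the normalized range [-1, 1]: x / 127.5 - 1.0 maps
# 0 -> -1.0 and 255 -> 1.0, matching the [-1, 1] range asserted above for batches that
# were already preprocessed.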
- data_batch[input_key] = data_batch[input_key].to(**self.tensor_kwargs) / 127.5 - 1.0 - data_batch[IS_PREPROCESSED_KEY] = True - - def _augment_image_dim_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: - input_key = self.input_image_key if input_key is None else input_key - if input_key in data_batch: - # Check if the data has already been augmented and avoid re-augmenting - if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: - assert ( - data_batch[input_key].shape[2] == 1 - ), f"Image data is claimed be augmented while its shape is {data_batch[input_key].shape}" - return - else: - data_batch[input_key] = rearrange(data_batch[input_key], "b c h w -> b c 1 h w").contiguous() - data_batch[IS_PREPROCESSED_KEY] = True - - -def broadcast_condition(condition: BaseVideoCondition, to_tp: bool = True, to_cp: bool = True) -> BaseVideoCondition: - condition_kwargs = {} - for k, v in condition.to_dict().items(): - if isinstance(v, torch.Tensor): - assert not v.requires_grad, f"{k} requires gradient. the current impl does not support it" - condition_kwargs[k] = parallel.broadcast(v, to_tp=to_tp, to_cp=to_cp) - condition = type(condition)(**condition_kwargs) - return condition diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_v2w.py b/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_v2w.py deleted file mode 100644 index 906cade4bea1a56e6028e49e402419f710f1fa82..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_v2w.py +++ /dev/null @@ -1,573 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
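# --- Editor's sketch (not part of the diff): EDM preconditioning used by denoise() ---
# The base denoise() above (and the video wrappers below) scale the network input and
# output with (c_skip, c_out, c_in, c_noise) and then recover eps from x0. This is a
# minimal, self-contained illustration. Assumptions: the closed-form coefficients are the
# standard ones from Karras et al. 2022 (arXiv:2206.00364, Eq. 7) and are assumed to match
# what EDMScaling returns; sigma_data = 0.5 is the common EDM default; `toy_net` and
# `toy_denoise` are hypothetical stand-ins, not names from this repository.
import torch

def edm_scaling(sigma: torch.Tensor, sigma_data: float = 0.5):
    # Standard EDM preconditioning coefficients (assumed equivalent to EDMScaling).
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / torch.sqrt(sigma**2 + sigma_data**2)
    c_in = 1.0 / torch.sqrt(sigma**2 + sigma_data**2)
    c_noise = 0.25 * torch.log(sigma)
    return c_skip, c_out, c_in, c_noise

def toy_denoise(net, xt: torch.Tensor, sigma: torch.Tensor):
    # Mirrors the structure of the denoise() method: precondition the input, run the
    # network, then form the clean-data prediction x0 and the noise prediction eps.
    c_skip, c_out, c_in, c_noise = edm_scaling(sigma)
    net_out = net(c_in * xt, c_noise)        # F_theta(c_in * x, c_noise)
    x0_pred = c_skip * xt + c_out * net_out  # D_theta(x) = c_skip * x + c_out * F_theta
    eps_pred = (xt - x0_pred) / sigma        # noise recovered from x0 and sigma
    return x0_pred, eps_pred

if __name__ == "__main__":
    toy_net = lambda x, t: torch.zeros_like(x)  # hypothetical network stand-in
    x0, eps = toy_denoise(toy_net, torch.randn(1, 4), torch.tensor(80.0))
# --- end of editor's sketch ---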
- -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union - -import torch -from megatron.core import parallel_state -from torch import Tensor - -from cosmos_transfer1.diffusion.conditioner import VideoExtendCondition -from cosmos_transfer1.diffusion.config.base.conditioner import VideoCondBoolConfig -from cosmos_transfer1.diffusion.diffusion.functional.batch_ops import batch_mul -from cosmos_transfer1.diffusion.model.model_t2w import DataType, DiffusionT2WModel, DistillT2WModel -from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp -from cosmos_transfer1.utils import log, misc - - -@dataclass -class VideoDenoisePrediction: - x0: torch.Tensor # clean data prediction - eps: Optional[torch.Tensor] = None # noise prediction - logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used a confidence / uncertainty - xt: Optional[torch.Tensor] = None # input to the network, before muliply with c_in - x0_pred_replaced: Optional[torch.Tensor] = None # x0 prediction with condition region replaced by gt_latent - - -class DiffusionV2WModel(DiffusionT2WModel): - def __init__(self, config): - super().__init__(config) - - def augment_conditional_latent_frames( - self, - condition: VideoExtendCondition, - cfg_video_cond_bool: VideoCondBoolConfig, - gt_latent: Tensor, - condition_video_augment_sigma_in_inference: float = 0.001, - sigma: Tensor = None, - seed: int = 1, - ) -> Union[VideoExtendCondition, Tensor]: - """Augments the conditional frames with noise during inference. - - Args: - condition (VideoExtendCondition): condition object - condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor. - condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network. - cfg_video_cond_bool (VideoCondBoolConfig): video condition bool config - gt_latent (Tensor): ground truth latent tensor in shape B,C,T,H,W - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - sigma (Tensor): noise level for the generation region - seed (int): random seed for reproducibility - Returns: - VideoExtendCondition: updated condition object - condition_video_augment_sigma: sigma for the condition region, feed to the network - augment_latent (Tensor): augmented latent tensor in shape B,C,T,H,W - - """ - - # Inference only, use fixed sigma for the condition region - assert ( - condition_video_augment_sigma_in_inference is not None - ), "condition_video_augment_sigma_in_inference should be provided" - augment_sigma = condition_video_augment_sigma_in_inference - - if augment_sigma >= sigma.flatten()[0]: - # This is a inference trick! If the sampling sigma is smaller than the augment sigma, we will start denoising the condition region together. - # This is achieved by setting all region as `generation`, i.e. 
value=0 - log.debug("augment_sigma larger than sigma or other frame, remove condition") - condition.condition_video_indicator = condition.condition_video_indicator * 0 - - B = gt_latent.shape[0] - augment_sigma = torch.full((B,), augment_sigma, **self.tensor_kwargs) - - # Now apply the augment_sigma to the gt_latent - - noise = misc.arch_invariant_rand( - gt_latent.shape, - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - - augment_latent = gt_latent + noise * augment_sigma[:, None, None, None, None] - - _, _, c_in_augment, _ = self.scaling(sigma=augment_sigma) - - # Multiply the whole latent with c_in_augment - augment_latent_cin = batch_mul(augment_latent, c_in_augment) - - # Since the whole latent will multiply with c_in later, we devide the value to cancel the effect - _, _, c_in, _ = self.scaling(sigma=sigma) - augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in) - - return condition, augment_latent_cin - - def denoise( - self, - noise_x: Tensor, - sigma: Tensor, - condition: VideoExtendCondition, - condition_video_augment_sigma_in_inference: float = 0.001, - seed: int = 1, - ) -> VideoDenoisePrediction: - """Denoises input tensor using conditional video generation. - - Args: - noise_x (Tensor): Noisy input tensor. - sigma (Tensor): Noise level. - condition (VideoExtendCondition): Condition for denoising. - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - seed (int): Random seed for reproducibility - Returns: - VideoDenoisePrediction containing: - - x0: Denoised prediction - - eps: Noise prediction - - logvar: Log variance of noise prediction - - xt: Input before c_in multiplication - - x0_pred_replaced: x0 prediction with condition regions replaced by ground truth - """ - - assert ( - condition.gt_latent is not None - ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}" - gt_latent = condition.gt_latent - cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool - - condition_latent = gt_latent - - # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed - condition, augment_latent = self.augment_conditional_latent_frames( - condition, cfg_video_cond_bool, condition_latent, condition_video_augment_sigma_in_inference, sigma, seed - ) - condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] - - if parallel_state.get_context_parallel_world_size() > 1: - cp_group = parallel_state.get_context_parallel_group() - condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) - augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group) - gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group) - - # Compose the model input with condition region (augment_latent) and generation region (noise_x) - new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x - # Call the abse model - denoise_pred = super().denoise(new_noise_xt, sigma, condition) - - x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 - - x0_pred = x0_pred_replaced - - return VideoDenoisePrediction( - x0=x0_pred, - eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), - logvar=denoise_pred.logvar, - xt=new_noise_xt, - 
x0_pred_replaced=x0_pred_replaced, - ) - - def generate_samples_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - seed: int = 1, - state_shape: Tuple | None = None, - n_sample: int | None = None, - is_negative_prompt: bool = False, - num_steps: int = 35, - condition_latent: Union[torch.Tensor, None] = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - add_input_frames_guidance: bool = False, - x_sigma_max: Optional[torch.Tensor] = None, - sigma_max: Optional[float] = None, - **kwargs, - ) -> Tensor: - """Generates video samples conditioned on input frames. - - Args: - data_batch: Input data dictionary - guidance: Classifier-free guidance scale - seed: Random seed for reproducibility - state_shape: Shape of output tensor (defaults to model's state shape) - n_sample: Number of samples to generate (defaults to batch size) - is_negative_prompt: Whether to use negative prompting - num_steps: Number of denoising steps - condition_latent: Conditioning frames tensor (B,C,T,H,W) - num_condition_t: Number of frames to condition on - condition_video_augment_sigma_in_inference: Noise level for condition augmentation - add_input_frames_guidance: Whether to apply guidance to input frames - x_sigma_max: Maximum noise level tensor - - Returns: - Generated video samples tensor - """ - - if n_sample is None: - input_key = self.input_data_key - n_sample = data_batch[input_key].shape[0] - if state_shape is None: - log.debug(f"Default Video state shape is used. {self.state_shape}") - state_shape = self.state_shape - - assert condition_latent is not None, "condition_latent should be provided" - - x0_fn = self.get_x0_fn_from_batch_with_condition_latent( - data_batch, - guidance, - is_negative_prompt=is_negative_prompt, - condition_latent=condition_latent, - num_condition_t=num_condition_t, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - add_input_frames_guidance=add_input_frames_guidance, - seed=seed, - ) - if sigma_max is None: - sigma_max = self.sde.sigma_max - if x_sigma_max is None: - x_sigma_max = ( - misc.arch_invariant_rand( - (n_sample,) + tuple(state_shape), - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - * sigma_max - ) - - if self.net.is_context_parallel_enabled: - x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - - samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max) - - if self.net.is_context_parallel_enabled: - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) - - return samples - - def get_x0_fn_from_batch_with_condition_latent( - self, - data_batch: Dict, - guidance: float = 1.5, - is_negative_prompt: bool = False, - condition_latent: torch.Tensor = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - add_input_frames_guidance: bool = False, - seed: int = 1, - ) -> Callable: - """Creates denoising function for conditional video generation. 
- - Args: - data_batch: Input data dictionary - guidance: Classifier-free guidance scale - is_negative_prompt: Whether to use negative prompting - condition_latent: Conditioning frames tensor (B,C,T,H,W) - num_condition_t: Number of frames to condition on - condition_video_augment_sigma_in_inference: Noise level for condition augmentation - add_input_frames_guidance: Whether to apply guidance to input frames - seed: Random seed for reproducibility - - Returns: - Function that takes noisy input and noise level and returns denoised prediction - """ - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - - condition.video_cond_bool = True - condition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, condition, num_condition_t - ) - - uncondition.video_cond_bool = False if add_input_frames_guidance else True - uncondition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, uncondition, num_condition_t - ) - - def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - cond_x0 = self.denoise( - noise_x, - sigma, - condition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed=seed, - ).x0_pred_replaced - uncond_x0 = self.denoise( - noise_x, - sigma, - uncondition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed=seed, - ).x0_pred_replaced - - return cond_x0 + guidance * (cond_x0 - uncond_x0) - - return x0_fn - - def add_condition_video_indicator_and_video_input_mask( - self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None - ) -> VideoExtendCondition: - """Adds conditioning masks to VideoExtendCondition object. - - Creates binary indicators and input masks for conditional video generation. 
- - Args: - latent_state: Input latent tensor (B,C,T,H,W) - condition: VideoExtendCondition object to update - num_condition_t: Number of frames to condition on - - Returns: - Updated VideoExtendCondition with added masks: - - condition_video_indicator: Binary tensor marking condition regions - - condition_video_input_mask: Input mask for network - - gt_latent: Ground truth latent tensor - """ - T = latent_state.shape[2] - latent_dtype = latent_state.dtype - condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( - latent_dtype - ) # 1 for condition region - - # Only in inference to decide the condition region - assert num_condition_t is not None, "num_condition_t should be provided" - assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" - log.debug( - f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" - ) - condition_video_indicator[:, :, :num_condition_t] += 1.0 - - condition.gt_latent = latent_state - condition.condition_video_indicator = condition_video_indicator - - B, C, T, H, W = latent_state.shape - # Create additional input_mask channel, this will be concatenated to the input of the network - # See design doc section (Implementation detail A.1 and A.2) for visualization - ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) - zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) - assert condition.video_cond_bool is not None, "video_cond_bool should be set" - - # The input mask indicate whether the input is conditional region or not - if condition.video_cond_bool: # Condition one given video frames - condition.condition_video_input_mask = ( - condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding - ) - else: # Unconditional case, use for cfg - condition.condition_video_input_mask = zeros_padding - - return condition - - -class DistillV2WModel(DistillT2WModel): - """ControlNet Video2World Distillation Model.""" - - def augment_conditional_latent_frames( - self, - condition: VideoExtendCondition, - cfg_video_cond_bool: VideoCondBoolConfig, - gt_latent: Tensor, - condition_video_augment_sigma_in_inference: float = 0.001, - sigma: Tensor = None, - seed: int = 1, - ) -> Union[VideoExtendCondition, Tensor]: - """Augments the conditional frames with noise during inference. - - Args: - condition (VideoExtendCondition): condition object - condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor. - condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network. 
- cfg_video_cond_bool (VideoCondBoolConfig): video condition bool config - gt_latent (Tensor): ground truth latent tensor in shape B,C,T,H,W - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - sigma (Tensor): noise level for the generation region - seed (int): random seed for reproducibility - Returns: - VideoExtendCondition: updated condition object - condition_video_augment_sigma: sigma for the condition region, feed to the network - augment_latent (Tensor): augmented latent tensor in shape B,C,T,H,W - - """ - # Inference only, use fixed sigma for the condition region - assert ( - condition_video_augment_sigma_in_inference is not None - ), "condition_video_augment_sigma_in_inference should be provided" - augment_sigma = condition_video_augment_sigma_in_inference - - if augment_sigma >= sigma.flatten()[0]: - # This is a inference trick! If the sampling sigma is smaller than the augment sigma, we will start denoising the condition region together. - # This is achieved by setting all region as `generation`, i.e. value=0 - log.debug("augment_sigma larger than sigma or other frame, remove condition") - condition.condition_video_indicator = condition.condition_video_indicator * 0 - - B = gt_latent.shape[0] - augment_sigma = torch.full((B,), augment_sigma, **self.tensor_kwargs) - - # Now apply the augment_sigma to the gt_latent - - noise = misc.arch_invariant_rand( - gt_latent.shape, - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - - augment_latent = gt_latent + noise * augment_sigma.view(B, 1, 1, 1, 1) - _, _, c_in_augment, c_noise_augment = self.scaling(sigma=augment_sigma) - - if cfg_video_cond_bool.condition_on_augment_sigma: # model takes augment_sigma as input - if condition.condition_video_indicator.sum() > 0: # has condition frames - condition.condition_video_augment_sigma = c_noise_augment - else: # no condition frames - condition.condition_video_augment_sigma = torch.zeros_like(c_noise_augment) - - # Multiply the whole latent with c_in_augment - augment_latent_cin = batch_mul(augment_latent, c_in_augment) - - # Since the whole latent will multiply with c_in later, we devide the value to cancel the effect - _, _, c_in, _ = self.scaling(sigma=sigma) - augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in) - - return condition, augment_latent_cin - - def drop_out_condition_region( - self, augment_latent: Tensor, noise_x: Tensor, cfg_video_cond_bool: VideoCondBoolConfig - ) -> Tensor: - """Use for CFG on input frames, we drop out the conditional region - There are two option: - 1. when we dropout, we set the region to be zero - 2. 
when we dropout, we set the region to be noise_x - """ - # Unconditional case, use for cfg - if cfg_video_cond_bool.cfg_unconditional_type == "zero_condition_region_condition_mask": - # Set the condition location input to be zero - augment_latent_drop = torch.zeros_like(augment_latent) - elif cfg_video_cond_bool.cfg_unconditional_type == "noise_x_condition_region": - # Set the condition location input to be noise_x, i.e., same as base model training - augment_latent_drop = noise_x - else: - raise NotImplementedError( - f"cfg_unconditional_type {cfg_video_cond_bool.cfg_unconditional_type} not implemented" - ) - return augment_latent_drop - - def denoise( - self, - noise_x: Tensor, - sigma: Tensor, - condition: VideoExtendCondition, - condition_video_augment_sigma_in_inference: float = 0.001, - seed: int = 1, - ) -> VideoDenoisePrediction: - """Denoises input tensor using conditional video generation. - - Args: - noise_x (Tensor): Noisy input tensor. - sigma (Tensor): Noise level. - condition (VideoExtendCondition): Condition for denoising. - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - seed (int): Random seed for reproducibility - Returns: - VideoDenoisePrediction containing: - - x0: Denoised prediction - - eps: Noise prediction - - logvar: Log variance of noise prediction - - xt: Input before c_in multiplication - - x0_pred_replaced: x0 prediction with condition regions replaced by ground truth - """ - inputs_to_check = [noise_x, sigma, condition.gt_latent] - for i, tensor in enumerate(inputs_to_check): - if torch.isnan(tensor).any(): - print(f"NaN found in input {i}") - assert ( - condition.gt_latent is not None - ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}" - gt_latent = condition.gt_latent - cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool - - condition_latent = gt_latent - - # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed - condition, augment_latent = self.augment_conditional_latent_frames( - condition, cfg_video_cond_bool, condition_latent, condition_video_augment_sigma_in_inference, sigma, seed - ) - condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] - - if parallel_state.get_context_parallel_world_size() > 1: - cp_group = parallel_state.get_context_parallel_group() - condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) - augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group) - gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group) - - if not condition.video_cond_bool: - # Unconditional case, drop out the condition region - augment_latent = self.drop_out_condition_region(augment_latent, xt, cfg_video_cond_bool) - - # Compose the model input with condition region (augment_latent) and generation region (noise_x) - new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x - # Call the abse model - denoise_pred = super().denoise(new_noise_xt, sigma, condition) - - x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 - - x0_pred = x0_pred_replaced - - return VideoDenoisePrediction( - x0=x0_pred, - eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), - 
logvar=denoise_pred.logvar, - xt=new_noise_xt, - x0_pred_replaced=x0_pred_replaced, - ) - - def add_condition_video_indicator_and_video_input_mask( - self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None - ) -> VideoExtendCondition: - """Adds conditioning masks to VideoExtendCondition object. - - Creates binary indicators and input masks for conditional video generation. - - Args: - latent_state: Input latent tensor (B,C,T,H,W) - condition: VideoExtendCondition object to update - num_condition_t: Number of frames to condition on - - Returns: - Updated VideoExtendCondition with added masks: - - condition_video_indicator: Binary tensor marking condition regions - - condition_video_input_mask: Input mask for network - - gt_latent: Ground truth latent tensor - """ - T = latent_state.shape[2] - latent_dtype = latent_state.dtype - condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( - latent_dtype - ) # 1 for condition region - - # Only in inference to decide the condition region - assert num_condition_t is not None, "num_condition_t should be provided" - assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" - log.debug( - f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" - ) - condition_video_indicator[:, :, :num_condition_t] += 1.0 - - condition.gt_latent = latent_state - condition.condition_video_indicator = condition_video_indicator - - B, C, T, H, W = latent_state.shape - # Create additional input_mask channel, this will be concatenated to the input of the network - ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) - zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) - assert condition.video_cond_bool is not None, "video_cond_bool should be set" - - # The input mask indicate whether the input is conditional region or not - if condition.video_cond_bool: # Condition one given video frames - condition.condition_video_input_mask = ( - condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding - ) - else: # Unconditional case, use for cfg - condition.condition_video_input_mask = zeros_padding - - return condition diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_v2w_multiview.py b/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_v2w_multiview.py deleted file mode 100644 index 43c9e9d63edb3427671d561d8dffc4e28e317023..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/model/model_v2w_multiview.py +++ /dev/null @@ -1,382 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
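# --- Editor's sketch (not part of the diff): condition-region composition and CFG ---
# Illustrates, in isolation, two ideas used repeatedly in the video-to-world models above:
# (1) the network input is composed from the condition region (noise-augmented ground-truth
# latent) and the generation region (noise_x) via a binary indicator, and the condition
# region is copied back into the prediction; (2) the sampler's x0 target combines the
# conditional and unconditional predictions as cond_x0 + guidance * (cond_x0 - uncond_x0).
# The formulas are taken from the code above; `toy_denoiser` is a hypothetical stand-in.
import torch

def compose_model_input(indicator, augment_latent, noise_x):
    # indicator is 1 in condition regions, 0 in generation regions (B,1,T,1,1 broadcast).
    return indicator * augment_latent + (1 - indicator) * noise_x

def replace_condition_region(indicator, gt_latent, x0_pred):
    # Keep ground-truth latents where conditioning applies, model prediction elsewhere.
    return indicator * gt_latent + (1 - indicator) * x0_pred

def guided_x0(cond_x0, uncond_x0, guidance: float = 1.5):
    # Classifier-free guidance combination used by x0_fn above.
    return cond_x0 + guidance * (cond_x0 - uncond_x0)

if __name__ == "__main__":
    B, C, T, H, W = 1, 4, 8, 2, 2
    indicator = torch.zeros(B, 1, T, 1, 1)
    indicator[:, :, :2] = 1.0                  # condition on the first 2 latent frames
    gt_latent = torch.randn(B, C, T, H, W)
    noise_x = torch.randn(B, C, T, H, W)
    xt = compose_model_input(indicator, gt_latent, noise_x)
    toy_denoiser = lambda x: 0.9 * x           # hypothetical denoiser stand-in
    # In the real model both branches denoise the same xt with different condition
    # objects; the toy denoiser has no notion of conditioning, so the difference is
    # faked with a scale factor purely to exercise the guidance formula.
    cond_x0 = replace_condition_region(indicator, gt_latent, toy_denoiser(xt))
    uncond_x0 = replace_condition_region(indicator, gt_latent, 0.8 * toy_denoiser(xt))
    x0 = guided_x0(cond_x0, uncond_x0, guidance=1.5)
# --- end of editor's sketch ---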
- -import copy -from dataclasses import dataclass -from typing import Callable, Dict, Optional, Tuple, Union - -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import Tensor - -from cosmos_transfer1.diffusion.conditioner import VideoExtendCondition -from cosmos_transfer1.diffusion.config.base.conditioner import VideoCondBoolConfig -from cosmos_transfer1.diffusion.diffusion.functional.batch_ops import batch_mul -from cosmos_transfer1.diffusion.model.model_t2w import broadcast_condition -from cosmos_transfer1.diffusion.model.model_v2w import DiffusionV2WModel -from cosmos_transfer1.diffusion.module.parallel import broadcast, cat_outputs_cp, split_inputs_cp -from cosmos_transfer1.utils import log, misc - - -def deepcopy_no_copy_model(obj): - """ - We need to create a copy of the condition construct such that condition masks can be adjusted dynamically, but - the controlnet encoder plug-in also uses the condition construct to pass along the base_model object which cannot be - deep-copied, hence this funciton - """ - if hasattr(obj, "base_model") and obj.base_model is not None: - my_base_model = obj.base_model - obj.base_model = None - copied_obj = copy.deepcopy(obj) - copied_obj.base_model = my_base_model - obj.base_model = my_base_model - else: - copied_obj = copy.deepcopy(obj) - return copied_obj - - -@dataclass -class VideoDenoisePrediction: - x0: torch.Tensor # clean data prediction - eps: Optional[torch.Tensor] = None # noise prediction - logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used a confidence / uncertainty - xt: Optional[torch.Tensor] = None # input to the network, before muliply with c_in - x0_pred_replaced: Optional[torch.Tensor] = None # x0 prediction with condition region replaced by gt_latent - - -class DiffusionV2WMultiviewModel(DiffusionV2WModel): - def __init__(self, config): - super().__init__(config) - self.n_views = config.n_views - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - state = rearrange(state, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - encoded_state = self.tokenizer.encode(state) - encoded_state = rearrange(encoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) * self.sigma_data - return encoded_state - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - decoded_state = self.tokenizer.decode(latent / self.sigma_data) - decoded_state = rearrange(decoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - return decoded_state - - def denoise( - self, - noise_x: Tensor, - sigma: Tensor, - condition: VideoExtendCondition, - condition_video_augment_sigma_in_inference: float = 0.001, - seed: int = 1, - ) -> VideoDenoisePrediction: - """Denoises input tensor using conditional video generation. - - Args: - noise_x (Tensor): Noisy input tensor. - sigma (Tensor): Noise level. - condition (VideoExtendCondition): Condition for denoising. 
- condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - seed (int): Random seed for reproducibility - Returns: - VideoDenoisePrediction containing: - - x0: Denoised prediction - - eps: Noise prediction - - logvar: Log variance of noise prediction - - xt: Input before c_in multiplication - - x0_pred_replaced: x0 prediction with condition regions replaced by ground truth - """ - - assert ( - condition.gt_latent is not None - ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}" - condition = deepcopy_no_copy_model(condition) - gt_latent = condition.gt_latent - cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool - - condition_latent = gt_latent - - # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed - condition, augment_latent = self.augment_conditional_latent_frames( - condition, cfg_video_cond_bool, condition_latent, condition_video_augment_sigma_in_inference, sigma, seed - ) - condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] - - if parallel_state.get_context_parallel_world_size() > 1: - cp_group = parallel_state.get_context_parallel_group() - condition_video_indicator = rearrange( - condition_video_indicator, "B C (V T) H W -> (B V) C T H W", V=self.n_views - ) - augment_latent = rearrange(augment_latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - gt_latent = rearrange(gt_latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - if getattr(condition, "view_indices_B_T", None) is not None: - view_indices_B_V_T = rearrange(condition.view_indices_B_T, "B (V T) -> (B V) T", V=self.n_views) - view_indices_B_V_T = split_inputs_cp(view_indices_B_V_T, seq_dim=1, cp_group=cp_group) - condition.view_indices_B_T = rearrange(view_indices_B_V_T, "(B V) T -> B (V T)", V=self.n_views) - condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) - augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group) - gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group) - - condition_video_indicator = rearrange( - condition_video_indicator, "(B V) C T H W -> B C (V T) H W", V=self.n_views - ) - augment_latent = rearrange(augment_latent, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - gt_latent = rearrange(gt_latent, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - - # Compose the model input with condition region (augment_latent) and generation region (noise_x) - new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x - # Call the abse model - denoise_pred = super(DiffusionV2WModel, self).denoise(new_noise_xt, sigma, condition) - - x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 - - x0_pred = x0_pred_replaced - - return VideoDenoisePrediction( - x0=x0_pred, - eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), - logvar=denoise_pred.logvar, - xt=new_noise_xt, - x0_pred_replaced=x0_pred_replaced, - ) - - def generate_samples_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - seed: int = 1, - state_shape: Tuple | None = None, - n_sample: int | None = None, - is_negative_prompt: bool = False, - num_steps: int = 35, - condition_latent: Union[torch.Tensor, None] = None, - 
num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - add_input_frames_guidance: bool = False, - x_sigma_max: Optional[torch.Tensor] = None, - sigma_max: Optional[float] = None, - **kwargs, - ) -> Tensor: - """Generates video samples conditioned on input frames. - - Args: - data_batch: Input data dictionary - guidance: Classifier-free guidance scale - seed: Random seed for reproducibility - state_shape: Shape of output tensor (defaults to model's state shape) - n_sample: Number of samples to generate (defaults to batch size) - is_negative_prompt: Whether to use negative prompting - num_steps: Number of denoising steps - condition_latent: Conditioning frames tensor (B,C,T,H,W) - num_condition_t: Number of frames to condition on - condition_video_augment_sigma_in_inference: Noise level for condition augmentation - add_input_frames_guidance: Whether to apply guidance to input frames - x_sigma_max: Maximum noise level tensor - - Returns: - Generated video samples tensor - """ - - if n_sample is None: - input_key = self.input_data_key - n_sample = data_batch[input_key].shape[0] - if state_shape is None: - log.debug(f"Default Video state shape is used. {self.state_shape}") - state_shape = self.state_shape - - assert condition_latent is not None, "condition_latent should be provided" - - x0_fn = self.get_x0_fn_from_batch_with_condition_latent( - data_batch, - guidance, - is_negative_prompt=is_negative_prompt, - condition_latent=condition_latent, - num_condition_t=num_condition_t, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - add_input_frames_guidance=add_input_frames_guidance, - seed=seed, - ) - if sigma_max is None: - sigma_max = self.sde.sigma_max - if x_sigma_max is None: - x_sigma_max = ( - misc.arch_invariant_rand( - (n_sample,) + tuple(state_shape), - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - * sigma_max - ) - - if self.net.is_context_parallel_enabled: - x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - - samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max) - - if self.net.is_context_parallel_enabled: - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) - - return samples - - def get_x0_fn_from_batch_with_condition_latent( - self, - data_batch: Dict, - guidance: float = 1.5, - is_negative_prompt: bool = False, - condition_latent: torch.Tensor = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - add_input_frames_guidance: bool = False, - seed: int = 1, - ) -> Callable: - """Creates denoising function for conditional video generation. 
- - Args: - data_batch: Input data dictionary - guidance: Classifier-free guidance scale - is_negative_prompt: Whether to use negative prompting - condition_latent: Conditioning frames tensor (B,C,T,H,W) - num_condition_t: Number of frames to condition on - condition_video_augment_sigma_in_inference: Noise level for condition augmentation - add_input_frames_guidance: Whether to apply guidance to input frames - seed: Random seed for reproducibility - - Returns: - Function that takes noisy input and noise level and returns denoised prediction - """ - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - - if "view_indices" in data_batch: - comp_factor = self.vae.temporal_compression_factor - view_indices = rearrange(data_batch["view_indices"], "B (V T) -> B V T", V=self.n_views) - view_indices_B_V_0 = view_indices[:, :, :1] - view_indices_B_V_1T = view_indices[:, :, 1:-1:comp_factor] - view_indices_B_V_T = torch.cat([view_indices_B_V_0, view_indices_B_V_1T], dim=-1) - condition.view_indices_B_T = rearrange(view_indices_B_V_T, "B V T -> B (V T)", V=self.n_views) - uncondition.view_indices_B_T = condition.view_indices_B_T - - condition.video_cond_bool = True - condition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, condition, num_condition_t - ) - - uncondition.video_cond_bool = False if add_input_frames_guidance else True - uncondition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, uncondition, num_condition_t - ) - - def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - cond_x0 = self.denoise( - noise_x, - sigma, - condition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed=seed, - ).x0_pred_replaced - uncond_x0 = self.denoise( - noise_x, - sigma, - uncondition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed=seed, - ).x0_pred_replaced - - return cond_x0 + guidance * (cond_x0 - uncond_x0) - - return x0_fn - - def add_condition_video_indicator_and_video_input_mask( - self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None - ) -> VideoExtendCondition: - """Adds conditioning masks to VideoExtendCondition object. - - Creates binary indicators and input masks for conditional video generation. 
- - Args: - latent_state: Input latent tensor (B,C,T,H,W) - condition: VideoExtendCondition object to update - num_condition_t: Number of frames to condition on - - Returns: - Updated VideoExtendCondition with added masks: - - condition_video_indicator: Binary tensor marking condition regions - - condition_video_input_mask: Input mask for network - - gt_latent: Ground truth latent tensor - """ - T = latent_state.shape[2] - latent_dtype = latent_state.dtype - condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( - latent_dtype - ) # 1 for condition region - condition_video_indicator = rearrange(condition_video_indicator, "B C (V T) H W -> B V C T H W", V=self.n_views) - if self.config.conditioner.video_cond_bool.condition_location == "first_cam": - # condition on first cam - condition_video_indicator[:, 0, :, :, :, :] += 1.0 - - elif self.config.conditioner.video_cond_bool.condition_location.startswith("fixed_cam_and_first_n"): - # condition on a list of cameras specified through the string - cond_vids = [int(c) for c in self.config.conditioner.video_cond_bool.condition_location.split("_")[5:]] - - for vidx in cond_vids: - condition_video_indicator[:, vidx, :, :, :, :] += 1.0 - # also condition on first n_condition_t frames - condition_video_indicator[:, :, :, :num_condition_t] += 1.0 - condition_video_indicator = condition_video_indicator.clamp(max=1.0) - - elif self.config.conditioner.video_cond_bool.condition_location.startswith("fixed_cam"): - # condition on a list of cameras specified through the string - cond_vids = [int(c) for c in self.config.conditioner.video_cond_bool.condition_location.split("_")[2:]] - - for vidx in cond_vids: - condition_video_indicator[:, vidx, :, :, :, :] += 1.0 - condition_video_indicator = torch.clamp(condition_video_indicator, 0, 1) - - elif self.config.conditioner.video_cond_bool.condition_location == "first_cam_and_first_n": - # condition on first cam - condition_video_indicator[:, 0, :, :, :, :] += 1.0 - condition_video_indicator[:, :, :, :num_condition_t] += 1.0 - condition_video_indicator = condition_video_indicator.clamp(max=1.0) - else: - raise NotImplementedError( - f"condition_location {self.config.conditioner.video_cond_bool.condition_location } not implemented" - ) - condition_video_indicator = rearrange( - condition_video_indicator, "B V C T H W -> B C (V T) H W", V=self.n_views - ) - - condition.gt_latent = latent_state - condition.condition_video_indicator = condition_video_indicator - - B, C, T, H, W = latent_state.shape - # Create additional input_mask channel, this will be concatenated to the input of the network - ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) - zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) - assert condition.video_cond_bool is not None, "video_cond_bool should be set" - - # The input mask indicate whether the input is conditional region or not - if condition.video_cond_bool: # Condition one given video frames - condition.condition_video_input_mask = ( - condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding - ) - else: # Unconditional case, use for cfg - condition.condition_video_input_mask = zeros_padding - - return condition diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/module/__init__.py b/cosmos-transfer1/cosmos_transfer1/diffusion/module/__init__.py deleted file mode 100644 index 
e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/module/attention.py b/cosmos-transfer1/cosmos_transfer1/diffusion/module/attention.py deleted file mode 100644 index 3e9e30f9efc6cd4e01a5a48a190f628785efcc46..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/module/attention.py +++ /dev/null @@ -1,500 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Optional - -import numpy as np -import torch -import transformer_engine as te -from einops import rearrange -from torch import Tensor, nn -from torch.utils.checkpoint import checkpoint -from transformer_engine.pytorch.attention.dot_product_attention.dot_product_attention import DotProductAttention -from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb - -# ---------------------- Feed Forward Network ----------------------- - - -class FeedForward(nn.Module): - """ - Transformer FFN with optional gating - - Parameters: - d_model (int): Dimensionality of input features. - d_ff (int): Dimensionality of the hidden layer. - dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1. - activation (callable, optional): The activation function applied after the first linear layer. - Defaults to nn.ReLU(). - is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer. - Defaults to False. - bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True. 
- - Example: - >>> ff = FeedForward(d_model=512, d_ff=2048) - >>> x = torch.randn(64, 10, 512) # Example input tensor - >>> output = ff(x) - >>> print(output.shape) # Expected shape: (64, 10, 512) - """ - - def __init__( - self, - d_model: int, - d_ff: int, - dropout: float = 0.1, - activation=nn.ReLU(), - is_gated: bool = False, - bias: bool = False, - ) -> None: - super().__init__() - - self.layer1 = nn.Linear(d_model, d_ff, bias=bias) - self.layer2 = nn.Linear(d_ff, d_model, bias=bias) - - self.dropout = nn.Dropout(dropout) - self.activation = activation - self.is_gated = is_gated - if is_gated: - self.linear_gate = nn.Linear(d_model, d_ff, bias=False) - - def forward(self, x: torch.Tensor): - g = self.activation(self.layer1(x)) - if self.is_gated: - x = g * self.linear_gate(x) - else: - x = g - assert self.dropout.p == 0.0, "we skip dropout" - return self.layer2(x) - - -class GPT2FeedForward(FeedForward): - def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False): - super().__init__( - d_model=d_model, - d_ff=d_ff, - dropout=dropout, - activation=nn.GELU(), - is_gated=False, - bias=bias, - ) - - def forward(self, x: torch.Tensor): - assert self.dropout.p == 0.0, "we skip dropout" - - x = self.layer1(x) - - def activation_layer2_forward(x): - x = self.activation(x) - x = self.layer2(x) - return x - - x = checkpoint(activation_layer2_forward, x, use_reentrant=False) - return x - - -# ---------------------- Normalization Layer ----------------------- - - -def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor: - """ - Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted. - - Args: - x (torch.Tensor): The input tensor to normalize. - dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first. - eps (float, optional): A small constant to ensure numerical stability during division. - - Returns: - torch.Tensor: The normalized tensor. 
- """ - if dim is None: - dim = list(range(1, x.ndim)) - norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) - norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel())) - return x / norm.to(x.dtype) - - -def get_normalization(name: str, channels: int): - if name == "I": - return nn.Identity() - elif name == "R": - return te.pytorch.RMSNorm(channels, eps=1e-6) - else: - raise ValueError(f"Normalization {name} not found") - - -class BaseAttentionOp(nn.Module): - def __init__(self): - super().__init__() - - -class RegionalAttentionOp(BaseAttentionOp): - def __init__( - self, - heads, - dim_head, - num_gqa_groups=None, - attention_dropout=0, - qkv_format="bshd", - attn_mask_type="no_mask", - tp_size=1, - tp_group=None, - sequence_parallel=False, - ): - super().__init__() - self.heads = heads - self.dim_head = dim_head - self.qkv_format = qkv_format - self.tp_size = tp_size - self.scale = dim_head**-0.5 - self.attention_dropout = attention_dropout - self.sequence_parallel = sequence_parallel - self.tp_group = tp_group - self.dot_product_attention = DotProductAttention( - self.heads, - self.dim_head, - num_gqa_groups=num_gqa_groups, - attention_dropout=attention_dropout, - qkv_format=qkv_format, - attn_mask_type=attn_mask_type, - tp_size=tp_size, - tp_group=tp_group, - sequence_parallel=sequence_parallel, - ) - - def forward( - self, - q, - k, - v, - regional_k=None, - regional_v=None, - region_masks=None, - core_attention_bias_type="no_bias", - core_attention_bias=None, - ): - # Early return for non-regional case - if regional_k is None or regional_v is None or region_masks is None: - return self.dot_product_attention( - q, - k, - v, - attention_mask=None, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - ) - # Get dimensions - is_bshd = self.qkv_format == "bshd" - if is_bshd: - batch_size, seq_len, num_heads, head_dim = q.shape - else: - seq_len, batch_size, num_heads, head_dim = q.shape - - # Process region masks - processed_masks = [] - prompt_len = k.shape[1] if is_bshd else k.shape[0] - num_regions = len(regional_k) - - def preprocess_mask(mask: Tensor) -> Tensor: - mask = mask.permute(3, 0, 1, 2) - B, T, H, W = mask.shape - mask = mask.unsqueeze(1) # dummy unsqueeze since trilinear interpolation expects 5D - - mask_i = [ - torch.nn.functional.interpolate( - mask[:, :, :1, :, :], - size=(1, 44, 80), - mode="trilinear", - align_corners=False, - ) - ] - for wi in range(1, T, 8): - mask_i += [ - torch.nn.functional.interpolate( - mask[:, :, wi : wi + 8, :, :], - size=(1, 44, 80), - mode="trilinear", - align_corners=False, - ) - ] - assert len(mask_i) == 16 - mask = torch.cat(mask_i, dim=2) - mask = mask.squeeze(1) - return (mask > 0.5).float() - - for i in range(num_regions): - mask = region_masks[i] - mask = mask.to(q.device) - if mask.shape[0] != seq_len: - mask = preprocess_mask(mask) - mask = rearrange(mask, "b t h w -> b (t h w)") - processed_masks.append(mask) - - hidden_seq_len = seq_len - regional_attention_mask = torch.zeros( - (batch_size, hidden_seq_len, (num_regions + 1) * prompt_len), device=q.device, dtype=torch.bool - ) - start_idx = 0 - for i, mask in enumerate(processed_masks): - regional_attention_mask[:, :, (i + 1) * prompt_len : (i + 2) * prompt_len] = mask.unsqueeze(-1).bool() - - regional_masks_tensor = torch.stack(processed_masks, dim=-1).bool() # [B, S, R] - global_mask = (regional_masks_tensor.sum(dim=-1) == 0).unsqueeze(-1).bool() # [B, S, 1] - regional_attention_mask[:, :, 
:prompt_len] = global_mask - combined_k = torch.cat([k] + regional_k, dim=0) - combined_v = torch.cat([v] + regional_v, dim=0) - - attn_bias = torch.zeros_like(regional_attention_mask, dtype=torch.float32) - attn_bias = attn_bias.masked_fill(~regional_attention_mask, float("-inf")) - attn_bias = attn_bias.unsqueeze(1).expand(-1, num_heads, -1, -1) - output = self.dot_product_attention( - q, - combined_k, - combined_v, - attention_mask=None, - core_attention_bias_type="post_scale_bias", - core_attention_bias=attn_bias, - ) - - base_ratio = 0.5 # signifies the weight of the global prompt - if base_ratio is not None: - base_output = self.dot_product_attention( - q, - k, - v, - attention_mask=None, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - ) - output = output * (1 - base_ratio) + base_output * base_ratio - - if self.tp_size > 1 and not self.sequence_parallel: - torch.distributed.all_reduce(output, group=self.tp_group) - - return output - - -class Attention(nn.Module): - """ - Generalized attention impl. - - Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided. - If `context_dim` is None, self-attention is assumed. - - Parameters: - query_dim (int): Dimension of each query vector. - context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed. - heads (int, optional): Number of attention heads. Defaults to 8. - dim_head (int, optional): Dimension of each head. Defaults to 64. - dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0. - attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default. - qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False. - out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False. - qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections. - Defaults to "SSI". - qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections. - Defaults to 'per_head'. Only support 'per_head'. 
- - Examples: - >>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1) - >>> query = torch.randn(10, 128) # Batch size of 10 - >>> context = torch.randn(10, 256) # Batch size of 10 - >>> output = attn(query, context) # Perform the attention operation - - Note: - https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - """ - - def __init__( - self, - query_dim: int, - context_dim=None, - heads=8, - dim_head=64, - dropout=0.0, - attn_op: Optional[BaseAttentionOp] = None, - qkv_bias: bool = False, - out_bias: bool = False, - qkv_norm: str = "SSI", - qkv_norm_mode: str = "per_head", - backend: str = "transformer_engine", - qkv_format: str = "sbhd", - ) -> None: - super().__init__() - - self.is_selfattn = context_dim is None # self attention - - inner_dim = dim_head * heads - context_dim = query_dim if context_dim is None else context_dim - - self.heads = heads - self.dim_head = dim_head - self.qkv_norm_mode = qkv_norm_mode - self.qkv_format = qkv_format - - if self.qkv_norm_mode == "per_head": - norm_dim = dim_head - else: - raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'") - - self.backend = backend - - self.to_q = nn.Sequential( - nn.Linear(query_dim, inner_dim, bias=qkv_bias), - get_normalization(qkv_norm[0], norm_dim), - ) - self.to_k = nn.Sequential( - nn.Linear(context_dim, inner_dim, bias=qkv_bias), - get_normalization(qkv_norm[1], norm_dim), - ) - self.to_v = nn.Sequential( - nn.Linear(context_dim, inner_dim, bias=qkv_bias), - get_normalization(qkv_norm[2], norm_dim), - ) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim, bias=out_bias), - nn.Dropout(dropout), - ) - - if attn_op: # use what is given - self.attn_op = attn_op - elif self.backend == "transformer_engine": - self.attn_op: BaseAttentionOp = DotProductAttention( - self.heads, - self.dim_head, - num_gqa_groups=self.heads, - attention_dropout=0, - qkv_format=qkv_format, - attn_mask_type="no_mask", - sequence_parallel=False, - ) - self.regional_attn_op = RegionalAttentionOp( - self.heads, - self.dim_head, - num_gqa_groups=self.heads, - attention_dropout=0, - qkv_format=qkv_format, - attn_mask_type="arbitrary", - ) - else: - raise ValueError(f"Backend {backend} not found") - - def cal_qkv( - self, x, context=None, mask=None, rope_emb=None, **kwargs - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - del kwargs - - """ - self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers. - Before 07/24/2024, these modules normalize across all heads. - After 07/24/2024, to support tensor parallelism and follow the common practice in the community, - we support to normalize per head. - To keep the checkpoint copatibility with the previous code, - we keep the nn.Sequential but call the projection and the normalization layers separately. - We use a flag `self.qkv_norm_mode` to control the normalization behavior. - The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head. - """ - if self.qkv_norm_mode == "per_head": - q = self.to_q[0](x) - context = x if context is None else context - k = self.to_k[0](context) - v = self.to_v[0](context) - q, k, v = map( - lambda t: rearrange(t, "b ... (n c) -> b ... 
n c", n=self.heads, c=self.dim_head), - (q, k, v), - ) - else: - raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'") - - q = self.to_q[1](q) - k = self.to_k[1](k) - v = self.to_v[1](v) - if self.is_selfattn and rope_emb is not None: # only apply to self-attention! - q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True) - k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True) - return q, k, v - - def cal_attn(self, q, k, v, mask=None): - if self.backend == "transformer_engine": - seq_dim = self.qkv_format.index("s") - assert ( - q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 - ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." - out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None) # [B, Mq, H, V] - return self.to_out(out) - else: - raise ValueError(f"Backend {self.backend} not found") - - def forward( - self, - x, - context=None, - mask=None, - rope_emb=None, - regional_contexts=None, - region_masks=None, - **kwargs, - ): - """ - Args: - x (Tensor): The query tensor of shape [B, Mq, K] - context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None - regional_contexts (Optional[Tensor]): Stacked regional context tensors [B, R, M, D] or [R, M, B, D] if THWBD format - region_masks (Optional[Tensor]): Region masks [B, R, S] or [R, S, B] if THWBD format - """ - q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs) - - # Early return if no regional contexts - if regional_contexts is None or region_masks is None: - return self.cal_attn(q, k, v, mask) - - # Process regional contexts - regional_k = [] - regional_v = [] - - # Determine format based on qkv_format - is_bshd = self.qkv_format == "bshd" - - # Get number of regions - num_regions = regional_contexts.shape[1] if is_bshd else regional_contexts.shape[0] - - # Process each region - for i in range(num_regions): - # Extract regional context - reg_context = regional_contexts[:, i] if is_bshd else regional_contexts[i] - - # Ensure correct dtype - if reg_context.dtype != context.dtype: - reg_context = reg_context.to(dtype=context.dtype) - - _, k_regional, v_regional = self.cal_qkv(x, reg_context, mask, rope_emb=rope_emb, **kwargs) - - regional_k.append(k_regional) - regional_v.append(v_regional) - - # Apply regional attention - combined_attn = self.regional_attn_op( - q, - k, # from global prompt - v, # from global prompt - regional_k=regional_k, - regional_v=regional_v, - region_masks=region_masks, - core_attention_bias_type="no_bias", - core_attention_bias=None, - ) - - # Apply output projection - return self.to_out(combined_attn) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/module/blocks.py b/cosmos-transfer1/cosmos_transfer1/diffusion/module/blocks.py deleted file mode 100644 index a63088504920918adf3b0d2f13d1c9972cd2ae35..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/module/blocks.py +++ /dev/null @@ -1,586 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import Optional - -import numpy as np -import torch -from einops import rearrange, repeat -from einops.layers.torch import Rearrange -from torch import nn - -from cosmos_transfer1.diffusion.module.attention import Attention, GPT2FeedForward -from cosmos_transfer1.utils import log - - -def modulate(x, shift, scale): - return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - - -class Timesteps(nn.Module): - def __init__(self, num_channels): - super().__init__() - self.num_channels = num_channels - - def forward(self, timesteps): - in_dype = timesteps.dtype - half_dim = self.num_channels // 2 - exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) - exponent = exponent / (half_dim - 0.0) - - emb = torch.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] - - sin_emb = torch.sin(emb) - cos_emb = torch.cos(emb) - emb = torch.cat([cos_emb, sin_emb], dim=-1) - - return emb.to(in_dype) - - -class TimestepEmbedding(nn.Module): - def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False): - super().__init__() - log.debug( - f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." - ) - self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora) - self.activation = nn.SiLU() - self.use_adaln_lora = use_adaln_lora - if use_adaln_lora: - self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False) - else: - self.linear_2 = nn.Linear(out_features, out_features, bias=True) - - def forward(self, sample: torch.Tensor) -> torch.Tensor: - emb = self.linear_1(sample) - emb = self.activation(emb) - emb = self.linear_2(emb) - - if self.use_adaln_lora: - adaln_lora_B_3D = emb - emb_B_D = sample - else: - emb_B_D = emb - adaln_lora_B_3D = None - - return emb_B_D, adaln_lora_B_3D - - -class FourierFeatures(nn.Module): - """ - Implements a layer that generates Fourier features from input tensors, based on randomly sampled - frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems. - - [B] -> [B, D] - - Parameters: - num_channels (int): The number of Fourier features to generate. - bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1. - normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize - the variance of the features. Defaults to False. 
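# --- Editor's sketch (not part of the diff) ---------------------------------
# A standalone restatement of the Timesteps sinusoidal embedding above: frequencies
# are log-spaced over half the channel count, and the cos/sin halves are concatenated.
import math
import torch

def timestep_embedding(timesteps, num_channels):
    half = num_channels // 2
    exponent = -math.log(10000) * torch.arange(half, dtype=torch.float32, device=timesteps.device) / half
    freqs = torch.exp(exponent)                         # [half]
    args = timesteps[:, None].float() * freqs[None]     # [B, half]
    return torch.cat([args.cos(), args.sin()], dim=-1)  # [B, num_channels]

emb = timestep_embedding(torch.tensor([0, 10, 250]), 256)
print(emb.shape)  # torch.Size([3, 256])
# ----------------------------------------------------------------------------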
- - Example: - >>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True) - >>> x = torch.randn(10, 256) # Example input tensor - >>> output = layer(x) - >>> print(output.shape) # Expected shape: (10, 256) - """ - - def __init__(self, num_channels, bandwidth=1, normalize=False): - super().__init__() - self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True) - self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True) - self.gain = np.sqrt(2) if normalize else 1 - - def forward(self, x, gain: float = 1.0): - """ - Apply the Fourier feature transformation to the input tensor. - - Args: - x (torch.Tensor): The input tensor. - gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1. - - Returns: - torch.Tensor: The transformed tensor, with Fourier features applied. - """ - in_dtype = x.dtype - x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32)) - x = x.cos().mul(self.gain * gain).to(in_dtype) - return x - - -class PatchEmbed(nn.Module): - """ - PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers, - depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions, - making it suitable for video and image processing tasks. It supports dividing the input into patches - and embedding each patch into a vector of size `out_channels`. - - Parameters: - - spatial_patch_size (int): The size of each spatial patch. - - temporal_patch_size (int): The size of each temporal patch. - - in_channels (int): Number of input channels. Default: 3. - - out_channels (int): The dimension of the embedding vector for each patch. Default: 768. - - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True. - """ - - def __init__( - self, - spatial_patch_size, - temporal_patch_size, - in_channels=3, - out_channels=768, - bias=True, - ): - super().__init__() - self.spatial_patch_size = spatial_patch_size - self.temporal_patch_size = temporal_patch_size - - self.proj = nn.Sequential( - Rearrange( - "b c (t r) (h m) (w n) -> b t h w (c r m n)", - r=temporal_patch_size, - m=spatial_patch_size, - n=spatial_patch_size, - ), - nn.Linear( - in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias - ), - ) - self.out = nn.Identity() - - def forward(self, x): - """ - Forward pass of the PatchEmbed module. - - Parameters: - - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where - B is the batch size, - C is the number of channels, - T is the temporal dimension, - H is the height, and - W is the width of the input. - - Returns: - - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. - """ - assert x.dim() == 5 - _, _, T, H, W = x.shape - assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 - assert T % self.temporal_patch_size == 0 - x = self.proj(x) - return self.out(x) - - -class FinalLayer(nn.Module): - """ - The final layer of video DiT. 
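# --- Editor's sketch (not part of the diff) ---------------------------------
# A minimal sketch of the PatchEmbed module above: fold each temporal x spatial patch
# into the channel axis with einops, then project with a single Linear. Patch sizes
# and channel counts below are illustrative.
import torch
from torch import nn
from einops.layers.torch import Rearrange

spatial_p, temporal_p, in_ch, out_ch = 2, 1, 3, 768
proj = nn.Sequential(
    Rearrange("b c (t r) (h m) (w n) -> b t h w (c r m n)",
              r=temporal_p, m=spatial_p, n=spatial_p),
    nn.Linear(in_ch * spatial_p * spatial_p * temporal_p, out_ch),
)

video = torch.randn(1, in_ch, 8, 32, 32)   # (B, C, T, H, W); T, H, W divisible by the patch sizes
tokens = proj(video)
print(tokens.shape)  # torch.Size([1, 8, 16, 16, 768])
# ----------------------------------------------------------------------------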
- """ - - def __init__( - self, - hidden_size, - spatial_patch_size, - temporal_patch_size, - out_channels, - use_adaln_lora: bool = False, - adaln_lora_dim: int = 256, - ): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear( - hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False - ) - self.hidden_size = hidden_size - self.n_adaln_chunks = 2 - self.use_adaln_lora = use_adaln_lora - if use_adaln_lora: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(hidden_size, adaln_lora_dim, bias=False), - nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False), - ) - else: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False) - ) - - def forward( - self, - x_BT_HW_D, - emb_B_D, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - ): - if self.use_adaln_lora: - assert adaln_lora_B_3D is not None - shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk( - 2, dim=1 - ) - else: - shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1) - - B = emb_B_D.shape[0] - T = x_BT_HW_D.shape[0] // B - shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T) - x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D) - - x_BT_HW_D = self.linear(x_BT_HW_D) - return x_BT_HW_D - - -class VideoAttn(nn.Module): - """ - Implements video attention with optional cross-attention capabilities. - - This module processes video features while maintaining their spatio-temporal structure. It can perform - self-attention within the video features or cross-attention with external context features. - - Parameters: - x_dim (int): Dimension of input feature vectors - context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention - num_heads (int): Number of attention heads - bias (bool): Whether to include bias in attention projections. Default: False - qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head" - x_format (str): Format of input tensor. Must be "BTHWD". Default: "THWBD" - n_views (int): Extra parameter used in multi-view diffusion model. It indicated total number of camera we model together. 
- Input shape: - - x: (T, H, W, B, D) video features - - context (optional): (M, B, D) context features for cross-attention - where: - T: temporal dimension - H: height - W: width - B: batch size - D: feature dimension - M: context sequence length - """ - - def __init__( - self, - x_dim: int, - context_dim: Optional[int], - num_heads: int, - bias: bool = False, - qkv_norm_mode: str = "per_head", - x_format: str = "THWBD", - n_views: int = 1, - ) -> None: - super().__init__() - self.n_views = n_views - self.x_format = x_format - - self.attn = Attention( - x_dim, - context_dim, - num_heads, - x_dim // num_heads, - qkv_bias=bias, - qkv_norm="RRI", - out_bias=bias, - qkv_norm_mode=qkv_norm_mode, - qkv_format="sbhd", - ) - - def forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - crossattn_mask: Optional[torch.Tensor] = None, - rope_emb_L_1_1_D: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Forward pass for video attention with regional prompting support. - - Args: - x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data. - context (Tensor): Context tensor of shape (B, M, D) or (M, B, D), - where M is the sequence length of the context. - crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms. - rope_emb_L_1_1_D (Optional[Tensor]): - Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. - - Returns: - Tensor: The output tensor with applied attention, maintaining the input shape. - """ - if context is not None and self.n_views > 1: - x_T_H_W_B_D = rearrange(x, "(v t) h w b d -> t h w (v b) d", v=self.n_views) - context_M_B_D = rearrange(context, "(v m) b d -> m (v b) d", v=self.n_views) - else: - x_T_H_W_B_D = x - context_M_B_D = context - T, H, W, B, D = x_T_H_W_B_D.shape - x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d") - if regional_contexts is not None: - regional_contexts = rearrange(regional_contexts, "r (v m) b d -> r m (v b) d", v=1) - if region_masks is not None: - r, t, h, w, b = region_masks.shape - region_masks = rearrange(region_masks, "r (v t) h w b -> r t h w (v b)", v=1) - x_THW_B_D = self.attn( - x_THW_B_D, - context_M_B_D, - crossattn_mask, - rope_emb=rope_emb_L_1_1_D, - regional_contexts=regional_contexts, - region_masks=region_masks, - ) - x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) - if context is not None and self.n_views > 1: - x_T_H_W_B_D = rearrange(x_T_H_W_B_D, "t h w (v b) d -> (v t) h w b d", v=self.n_views) - return x_T_H_W_B_D - - -def adaln_norm_state(norm_state, x, scale, shift): - normalized = norm_state(x) - return normalized * (1 + scale) + shift - - -class DITBuildingBlock(nn.Module): - """ - A building block for the DiT (Diffusion Transformer) architecture that supports different types of - attention and MLP operations with adaptive layer normalization. - - Parameters: - block_type (str): Type of block - one of: - - "cross_attn"/"ca": Cross-attention - - "full_attn"/"fa": Full self-attention - - "mlp"/"ff": MLP/feedforward block - x_dim (int): Dimension of input features - context_dim (Optional[int]): Dimension of context features for cross-attention - num_heads (int): Number of attention heads - mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0 - bias (bool): Whether to use bias in layers. 
Default: False - mlp_dropout (float): Dropout rate for MLP. Default: 0.0 - qkv_norm_mode (str): QKV normalization mode. Default: "per_head" - x_format (str): Input tensor format. Default: "THWBD" - use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False - adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256 - n_views (int): Extra parameter used in multi-view diffusion model. It indicated total number of camera we model together. - """ - - def __init__( - self, - block_type: str, - x_dim: int, - context_dim: Optional[int], - num_heads: int, - mlp_ratio: float = 4.0, - bias: bool = False, - mlp_dropout: float = 0.0, - qkv_norm_mode: str = "per_head", - x_format: str = "THWBD", - use_adaln_lora: bool = False, - adaln_lora_dim: int = 256, - n_views: int = 1, - ) -> None: - block_type = block_type.lower() - - super().__init__() - self.x_format = x_format - if block_type in ["cross_attn", "ca"]: - self.block = VideoAttn( - x_dim, - context_dim, - num_heads, - bias=bias, - qkv_norm_mode=qkv_norm_mode, - x_format=self.x_format, - n_views=n_views, - ) - elif block_type in ["full_attn", "fa"]: - self.block = VideoAttn( - x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format - ) - elif block_type in ["mlp", "ff"]: - self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias) - else: - raise ValueError(f"Unknown block type: {block_type}") - - self.block_type = block_type - self.use_adaln_lora = use_adaln_lora - - self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) - self.n_adaln_chunks = 3 - if use_adaln_lora: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(x_dim, adaln_lora_dim, bias=False), - nn.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False), - ) - else: - self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False)) - - def forward( - self, - x: torch.Tensor, - emb_B_D: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - rope_emb_L_1_1_D: Optional[torch.Tensor] = None, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Forward pass for dynamically configured blocks with adaptive normalization. - - Args: - x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D). Current only support (T, H, W, B, D). - emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation. - crossattn_emb (Tensor): Tensor for cross-attention blocks. - crossattn_mask (Optional[Tensor]): Optional mask for cross-attention. - rope_emb_L_1_1_D (Optional[Tensor]): Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. - adaln_lora_B_3D (Optional[Tensor]): Additional embedding for adaptive layer norm. - regional_contexts (Optional[List[Tensor]]): List of regional context tensors. - region_masks (Optional[Tensor]): Region masks of shape (B, R, THW). - - Returns: - Tensor: The output tensor after processing through the configured block and adaptive normalization. 
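# --- Editor's sketch (not part of the diff) ---------------------------------
# A reduced sketch of the DiT building-block pattern above: the conditioning embedding
# is mapped to (shift, scale, gate), the normalized input is modulated, and the block
# output is added back through the gate. An MLP stands in for the attention/MLP
# variants; the (T, H, W, B, D) layout mirrors the "THWBD" format.
import torch
from torch import nn

D, B, T, H, W = 64, 2, 3, 4, 4
norm_state = nn.LayerNorm(D, elementwise_affine=False, eps=1e-6)
adaLN = nn.Sequential(nn.SiLU(), nn.Linear(D, 3 * D, bias=False))
block = nn.Sequential(nn.Linear(D, 4 * D), nn.GELU(), nn.Linear(4 * D, D))

x = torch.randn(T, H, W, B, D)
emb_B_D = torch.randn(B, D)

# each of shift/scale/gate: [B, D] -> [1, 1, 1, B, D] so it broadcasts over T, H, W
shift, scale, gate = [t[None, None, None] for t in adaLN(emb_B_D).chunk(3, dim=1)]
x = x + gate * block(norm_state(x) * (1 + scale) + shift)
print(x.shape)  # torch.Size([3, 4, 4, 2, 64])
# ----------------------------------------------------------------------------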
- """ - if self.use_adaln_lora: - shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( - self.n_adaln_chunks, dim=1 - ) - else: - shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) - - shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = ( - shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), - scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), - gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), - ) - - if self.block_type in ["mlp", "ff"]: - x = x + gate_1_1_1_B_D * self.block( - adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), - ) - elif self.block_type in ["full_attn", "fa"]: - x = x + gate_1_1_1_B_D * self.block( - adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), - context=None, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - ) - elif self.block_type in ["cross_attn", "ca"]: - normalized_x = adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D) - x = x + gate_1_1_1_B_D * self.block( - normalized_x, - context=crossattn_emb, - crossattn_mask=crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - regional_contexts=regional_contexts, - region_masks=region_masks, - ) - else: - raise ValueError(f"Unknown block type: {self.block_type}") - - return x - - -class GeneralDITTransformerBlock(nn.Module): - """ - A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer. - Each block in the sequence is specified by a block configuration string. - - Parameters: - x_dim (int): Dimension of input features - context_dim (int): Dimension of context features for cross-attention blocks - num_heads (int): Number of attention heads - block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention, - full-attention, then MLP) - mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0 - x_format (str): Input tensor format. Default: "THWBD" - use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False - adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256 - - The block_config string uses "-" to separate block types: - - "ca"/"cross_attn": Cross-attention block - - "fa"/"full_attn": Full self-attention block - - "mlp"/"ff": MLP/feedforward block - - Example: - block_config = "ca-fa-mlp" creates a sequence of: - 1. Cross-attention block - 2. Full self-attention block - 3. 
MLP block - """ - - def __init__( - self, - x_dim: int, - context_dim: int, - num_heads: int, - block_config: str, - mlp_ratio: float = 4.0, - x_format: str = "THWBD", - use_adaln_lora: bool = False, - adaln_lora_dim: int = 256, - n_views: int = 1, - ): - super().__init__() - self.blocks = nn.ModuleList() - self.x_format = x_format - for block_type in block_config.split("-"): - self.blocks.append( - DITBuildingBlock( - block_type, - x_dim, - context_dim, - num_heads, - mlp_ratio, - x_format=self.x_format, - use_adaln_lora=use_adaln_lora, - adaln_lora_dim=adaln_lora_dim, - n_views=n_views, - ) - ) - - def forward( - self, - x: torch.Tensor, - emb_B_D: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - rope_emb_L_1_1_D: Optional[torch.Tensor] = None, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - extra_per_block_pos_emb: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if extra_per_block_pos_emb is not None: - x = x + extra_per_block_pos_emb - for block in self.blocks: - x = block( - x, - emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - adaln_lora_B_3D=adaln_lora_B_3D, - regional_contexts=regional_contexts, - region_masks=region_masks, - ) - return x - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/module/parallel.py b/cosmos-transfer1/cosmos_transfer1/diffusion/module/parallel.py deleted file mode 100644 index e08356c9bff34bcfe989cd64614ce6ba0f506e85..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/module/parallel.py +++ /dev/null @@ -1,163 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from megatron.core import parallel_state -from torch import Tensor -from torch.distributed import ProcessGroup, all_gather, broadcast_object_list, get_process_group_ranks, get_world_size -from torch.distributed.utils import _verify_param_shape_across_processes - -from cosmos_transfer1.utils import distributed - - -def split_inputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: - """ - Split input tensor along the sequence dimension for checkpoint parallelism. - - This function divides the input tensor into equal parts along the specified - sequence dimension, based on the number of ranks in the checkpoint parallelism group. - It then selects the part corresponding to the current rank. - - Args: - x: Input tensor to be split. - seq_dim: The dimension along which to split the input (sequence dimension). - cp_group: The process group for checkpoint parallelism. - - Returns: - A slice of the input tensor corresponding to the current rank. 
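# --- Editor's sketch (not part of the diff) ---------------------------------
# A single-process illustration of split_inputs_cp above: the sequence axis is
# reshaped into (cp_size, seq // cp_size) and each rank keeps its own slice. Here the
# rank index is just a loop variable instead of cp_group.rank().
import torch

def split_for_rank(x, seq_dim, cp_size, rank):
    assert x.shape[seq_dim] % cp_size == 0
    x = x.view(*x.shape[:seq_dim], cp_size, x.shape[seq_dim] // cp_size, *x.shape[seq_dim + 1:])
    x = x.index_select(seq_dim, torch.tensor([rank], device=x.device))
    return x.view(*x.shape[:seq_dim], -1, *x.shape[seq_dim + 2:])

x = torch.arange(2 * 8 * 4).reshape(2, 8, 4)      # (batch, seq, dim), seq split over 4 ranks
shards = [split_for_rank(x, seq_dim=1, cp_size=4, rank=r) for r in range(4)]
print(shards[0].shape)                             # torch.Size([2, 2, 4])
assert torch.equal(torch.cat(shards, dim=1), x)    # concatenating the shards recovers the input
# ----------------------------------------------------------------------------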
- - Raises: - AssertionError: If the sequence dimension is not divisible by the number of ranks. - """ - cp_ranks = get_process_group_ranks(cp_group) - cp_size = len(cp_ranks) - - assert x.shape[seq_dim] % cp_size == 0, f"{x.shape[seq_dim]} cannot divide cp_size {cp_size}" - x = x.view(*x.shape[:seq_dim], cp_size, x.shape[seq_dim] // cp_size, *x.shape[(seq_dim + 1) :]) - seq_idx = torch.tensor([cp_group.rank()], device=x.device) - x = x.index_select(seq_dim, seq_idx) - # Note that the new sequence length is the original sequence length / cp_size - x = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) - return x - - -def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: - """ - Concatenate outputs from different ranks in the checkpoint parallelism group. - - This function gathers tensors from all ranks in the checkpoint parallelism group - and concatenates them along the specified sequence dimension. - - Args: - x: Input tensor to be concatenated. - seq_dim: The dimension along which to concatenate the tensors (sequence dimension). - cp_group: The process group for checkpoint parallelism. - - Returns: - A tensor that is the concatenation of tensors from all ranks in the cp_group. - - Raises: - RuntimeError: If the gather operation fails. - """ - # Get the world size (number of processes in the group) - world_size = get_world_size(cp_group) - - # Create a list to store tensors from all ranks - gathered_tensors = [torch.zeros_like(x) for _ in range(world_size)] - - # Gather tensors from all ranks - try: - all_gather(gathered_tensors, x, group=cp_group) - except RuntimeError as e: - raise RuntimeError(f"Failed to gather tensors: {e}") - - # Concatenate the gathered tensors along the specified dimension - return torch.cat(gathered_tensors, dim=seq_dim) - - -def broadcast(item: torch.Tensor | str | None, to_tp: bool = True, to_cp: bool = True) -> torch.Tensor | str | None: - """ - Broadcast the item from the minimum rank in the specified group(s). - Since global rank = tp_rank + cp_rank * tp_size + ... - First broadcast in the tp_group and then in the cp_group will - ensure that the item is broadcasted across ranks in cp_group and tp_group. - - Parameters: - - item: The item to broadcast (can be a torch.Tensor, str, or None). - - to_tp: Whether to broadcast to the tensor model parallel group. - - to_cp: Whether to broadcast to the context parallel group. 
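# --- Editor's sketch (not part of the diff) ---------------------------------
# A sketch of the gather-and-concatenate step in cat_outputs_cp above, assuming a
# process group has been initialized elsewhere: each rank contributes its local shard
# and receives the full sequence back. The guard keeps the snippet importable and
# runnable without torch.distributed being set up.
import torch
import torch.distributed as dist

def gather_seq(local_shard, seq_dim, group=None):
    world = dist.get_world_size(group)
    buckets = [torch.zeros_like(local_shard) for _ in range(world)]
    dist.all_gather(buckets, local_shard, group=group)   # same-shaped shard from every rank
    return torch.cat(buckets, dim=seq_dim)

if dist.is_available() and dist.is_initialized():
    shard = torch.randn(2, 4, 8)              # (batch, local_seq, dim)
    full = gather_seq(shard, seq_dim=1)       # (batch, local_seq * world_size, dim)
# ----------------------------------------------------------------------------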
- """ - if not parallel_state.is_initialized(): - return item - tp_group = parallel_state.get_tensor_model_parallel_group() - cp_group = parallel_state.get_context_parallel_group() - - to_tp = to_tp and parallel_state.get_tensor_model_parallel_world_size() > 1 - to_cp = to_cp and parallel_state.get_context_parallel_world_size() > 1 - - if to_tp: - min_tp_rank = min(get_process_group_ranks(tp_group)) - - if to_cp: - min_cp_rank = min(get_process_group_ranks(cp_group)) - - if isinstance(item, torch.Tensor): # assume the device is cuda - # log.info(f"{item.shape}", rank0_only=False) - if to_tp: - # torch.distributed.broadcast(item, min_tp_rank, group=tp_group) - item = _robust_broadcast(item, min_tp_rank, tp_group) - if to_cp: - # torch.distributed.broadcast(item, min_cp_rank, group=cp_group) - item = _robust_broadcast(item, min_cp_rank, cp_group) - elif item is not None: - broadcastable_list = [item] - if to_tp: - # log.info(f"{broadcastable_list}", rank0_only=False) - broadcast_object_list(broadcastable_list, min_tp_rank, group=tp_group) - if to_cp: - broadcast_object_list(broadcastable_list, min_cp_rank, group=cp_group) - - item = broadcastable_list[0] - return item - - -def _robust_broadcast(tensor: torch.Tensor, src: int, pg, is_check_shape: bool = False) -> torch.Tensor: - """ - Perform a robust broadcast operation that works regardless of tensor shapes on different ranks. - - Args: - tensor (torch.Tensor): The tensor to broadcast (on src rank) or receive (on other ranks). - src (int): The source rank for the broadcast. Defaults to 0. - - Returns: - torch.Tensor: The broadcasted tensor on all ranks. - """ - # First, broadcast the shape of the tensor - if distributed.get_rank() == src: - shape = torch.tensor(tensor.shape).cuda() - else: - shape = torch.empty(tensor.dim(), dtype=torch.long).cuda() - if is_check_shape: - _verify_param_shape_across_processes(pg, [shape]) - torch.distributed.broadcast(shape, src, group=pg) - - # Resize the tensor on non-src ranks if necessary - if distributed.get_rank() != src: - tensor = tensor.new_empty(shape.tolist()).type_as(tensor) - - # Now broadcast the tensor data - torch.distributed.broadcast(tensor, src, group=pg) - - return tensor diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/module/position_embedding.py b/cosmos-transfer1/cosmos_transfer1/diffusion/module/position_embedding.py deleted file mode 100644 index 236aee03e7dd74929e4f7ca8ddba5f10667738d9..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/module/position_embedding.py +++ /dev/null @@ -1,905 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Literal, Optional - -import numpy as np -import torch -import torch.nn.functional as F -from einops import rearrange, repeat -from torch import nn -from torch.distributed import ProcessGroup, get_process_group_ranks - -from cosmos_transfer1.diffusion.module.attention import normalize -from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp -from cosmos_transfer1.diffusion.module.timm import trunc_normal_ - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position - pos: a list of positions to be encoded: size (M,) - out: (M, D) - """ - assert embed_dim % 2 == 0 - omega = np.arange(embed_dim // 2, dtype=np.float64) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - -def get_3d_sincos_pos_embed( - embed_dim, - grid_size_h, - grid_size_w, - grid_size_t, - spatial_interpolation_scale, - temporal_interpolation_scale, - concat=True, -): - grid_h = np.arange(grid_size_h, dtype=np.float32) / spatial_interpolation_scale - grid_w = np.arange(grid_size_w, dtype=np.float32) / spatial_interpolation_scale - grid_t = np.arange(grid_size_t, dtype=np.float32) / temporal_interpolation_scale - - grid = np.meshgrid(grid_w, grid_h, grid_t, indexing="ij") - grid = np.stack(grid, axis=0) - grid = grid.reshape(3, 1, grid_size_h, grid_size_w, grid_size_t) - - if concat: - per_axis = embed_dim // 3 - per_axis = (per_axis // 2) * 2 # make it even (for sin/cos split) - dim_h, dim_w = per_axis, per_axis - dim_t = embed_dim - dim_h - dim_w - emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, grid[0]) # (H*W, D/3) - emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, grid[1]) # (H*W, D/3) - emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, grid[2]) # (H*W, D/3) - - return np.concatenate([emb_h, emb_w, emb_t], axis=1) # (H*W*T, D) - else: - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim, grid[0]) # (H*W) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim, grid[1]) # (H*W) - emb_t = get_1d_sincos_pos_embed_from_grid(embed_dim, grid[2]) # (H*W) - - return emb_h + emb_w + emb_t # (H*W*T, D) - - -class VideoPositionEmb(nn.Module): - def __init__(self): - super().__init__() - self.cp_group = None - - def enable_context_parallel(self, cp_group: ProcessGroup): - self.cp_group = cp_group - - def disable_context_parallel(self): - self.cp_group = None - - def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor]) -> torch.Tensor: - """ - It delegates the embedding generation to generate_embeddings function. 
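# --- Editor's sketch (not part of the diff) ---------------------------------
# The 1D sin/cos table from get_1d_sincos_pos_embed_from_grid above, restated with a
# quick shape check: for each position p and frequency omega_d the row is
# [sin(p * omega), cos(p * omega)] concatenated along the channel axis.
import numpy as np

def sincos_1d(embed_dim, pos):
    assert embed_dim % 2 == 0
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0))
    out = np.einsum("m,d->md", pos.reshape(-1), omega)          # (M, D/2) outer product
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)   # (M, D)

emb = sincos_1d(16, np.arange(10, dtype=np.float64))
print(emb.shape)  # (10, 16)
# ----------------------------------------------------------------------------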
- """ - B_T_H_W_C = x_B_T_H_W_C.shape - if self.cp_group is not None: - cp_ranks = get_process_group_ranks(self.cp_group) - cp_size = len(cp_ranks) - B, T, H, W, C = B_T_H_W_C - B_T_H_W_C = (B, T * cp_size, H, W, C) - embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps) - - if self.cp_group is not None: - if isinstance(self, VideoRopePosition3DEmb): - seq_dim = 0 - else: - seq_dim = 1 - embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) - return embeddings - - def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]): - raise NotImplementedError - - -class VideoRopePositionEmb(VideoPositionEmb): - def __init__( - self, - *, # enforce keyword arguments - head_dim: int, - len_h: int, - len_w: int, - len_t: int, - **kwargs, # used for compatibility with other positional embeddings; unused in this class - ): - del kwargs - super().__init__() - self.register_buffer("seq", torch.arange(len_h * len_w * len_t, dtype=torch.float)) - - self.register_buffer( - "dim_range", torch.arange(0, head_dim, 2)[: (head_dim // 2)].float().cuda() / head_dim, persistent=False - ) - - def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], ntk_factor: float = 1.0): - theta = 10000.0 * ntk_factor - - # original_dtype = self.dim_range.dtype - freq = 1.0 / (theta ** self.dim_range.float()) - _, T, H, W, _ = B_T_H_W_C - length = T * H * W - emb_L_D = torch.outer(self.seq[:length], freq) - return rearrange(torch.cat([emb_L_D, emb_L_D], dim=-1), "l d -> l 1 1 d").float() - - -class VideoRopePosition3DEmb(VideoPositionEmb): - def __init__( - self, - *, # enforce keyword arguments - head_dim: int, - len_h: int, - len_w: int, - len_t: int, - base_fps: int = 24, - h_extrapolation_ratio: float = 1.0, - w_extrapolation_ratio: float = 1.0, - t_extrapolation_ratio: float = 1.0, - **kwargs, # used for compatibility with other positional embeddings; unused in this class - ): - del kwargs - super().__init__() - self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float)) - self.base_fps = base_fps - self.max_h = len_h - self.max_w = len_w - - dim = head_dim - dim_h = dim // 6 * 2 - dim_w = dim_h - dim_t = dim - 2 * dim_h - assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" - self.register_buffer( - "dim_spatial_range", - torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h, - persistent=False, - ) - self.register_buffer( - "dim_temporal_range", - torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t, - persistent=False, - ) - - self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) - self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) - self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) - - def generate_embeddings( - self, - B_T_H_W_C: torch.Size, - fps: Optional[torch.Tensor] = None, - h_ntk_factor: Optional[float] = None, - w_ntk_factor: Optional[float] = None, - t_ntk_factor: Optional[float] = None, - ): - """ - Generate embeddings for the given input size. - - Args: - B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). - fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. - h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. - w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. - t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. 
- - Returns: - Not specified in the original code snippet. - """ - h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor - w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor - t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor - - h_theta = 10000.0 * h_ntk_factor - w_theta = 10000.0 * w_ntk_factor - t_theta = 10000.0 * t_ntk_factor - - h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) - w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) - temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) - - B, T, H, W, _ = B_T_H_W_C - uniform_fps = (fps is None) or (fps.min() == fps.max()) - assert ( - uniform_fps or B == 1 or T == 1 - ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" - assert ( - H <= self.max_h and W <= self.max_w - ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})" - half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs) - half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs) - - # apply sequence scaling in temporal dimension - if fps is None: # image case - assert T == 1, "T should be 1 for image batch." - half_emb_t = torch.outer(self.seq[:T], temporal_freqs) - else: - half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs) - - em_T_H_W_D = torch.cat( - [ - repeat(half_emb_t, "t d -> t h w d", h=H, w=W), - repeat(half_emb_h, "h d -> t h w d", t=T, w=W), - repeat(half_emb_w, "w d -> t h w d", t=T, h=H), - ] - * 2, - dim=-1, - ) - - return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() - - -class LearnablePosEmbAxis(VideoPositionEmb): - def __init__( - self, - *, # enforce keyword arguments - interpolation: str, - model_channels: int, - len_h: int, - len_w: int, - len_t: int, - **kwargs, - ): - """ - Args: - interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. 
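# --- Editor's sketch (not part of the diff) ---------------------------------
# A small sketch of the factorized "axis" positional embedding above: one learnable
# table per axis, cropped to the current T/H/W and broadcast-summed into a full
# (B, T, H, W, D) grid. A plain F.normalize stands in for the module's `normalize`
# helper, which is an assumption about its behaviour.
import torch
import torch.nn.functional as F
from torch import nn
from einops import repeat

D, len_t, len_h, len_w = 32, 16, 20, 20
pos_t = nn.Parameter(torch.zeros(len_t, D)); nn.init.trunc_normal_(pos_t, std=0.02)
pos_h = nn.Parameter(torch.zeros(len_h, D)); nn.init.trunc_normal_(pos_h, std=0.02)
pos_w = nn.Parameter(torch.zeros(len_w, D)); nn.init.trunc_normal_(pos_w, std=0.02)

B, T, H, W = 2, 8, 10, 12                      # "crop" interpolation: slice each axis table
emb = (
    repeat(pos_t[:T], "t d -> b t h w d", b=B, h=H, w=W)
    + repeat(pos_h[:H], "h d -> b t h w d", b=B, t=T, w=W)
    + repeat(pos_w[:W], "w d -> b t h w d", b=B, t=T, h=H)
)
emb = F.normalize(emb, dim=-1, eps=1e-6)
print(emb.shape)  # torch.Size([2, 8, 10, 12, 32])
# ----------------------------------------------------------------------------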
- """ - del kwargs # unused - super().__init__() - self.interpolation = interpolation - assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" - - self.pos_emb_h = nn.Parameter(torch.zeros(len_h, model_channels)) - self.pos_emb_w = nn.Parameter(torch.zeros(len_w, model_channels)) - self.pos_emb_t = nn.Parameter(torch.zeros(len_t, model_channels)) - - trunc_normal_(self.pos_emb_h, std=0.02) - trunc_normal_(self.pos_emb_w, std=0.02) - trunc_normal_(self.pos_emb_t, std=0.02) - - def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: - B, T, H, W, _ = B_T_H_W_C - if self.interpolation == "crop": - emb_h_H = self.pos_emb_h[:H] - emb_w_W = self.pos_emb_w[:W] - emb_t_T = self.pos_emb_t[:T] - emb = ( - repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W) - + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W) - + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H) - ) - assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" - else: - raise ValueError(f"Unknown interpolation method {self.interpolation}") - - return normalize(emb, dim=-1, eps=1e-6) - - -class LearnableEmb3D(VideoPositionEmb): - def __init__( - self, - *, # enforce keyword arguments - model_channels: int, - len_h: int, - len_w: int, - len_t: int, - interpolation: str = "crop", - is_learnable: bool = True, - **kwargs, # used for compatibility with other positional embeddings; unused in this class - ): - del kwargs # unused - super().__init__() - assert is_learnable is True - self.interpolation = interpolation - self.pos_embed = nn.Parameter(torch.zeros(1, len_t, len_h, len_w, model_channels)) - trunc_normal_(self.pos_embed, std=0.02) - - def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: - B, T, H, W, C = B_T_H_W_C - if self.interpolation == "crop": - return self.pos_embed[:, :T, :H, :W] - if self.interpolation == "resize": - return rearrange( - F.interpolate( - rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), - size=(H, W, T), - mode="linear", - align_corners=False, - ), - "1 c h w t -> 1 t h w c", - ) - raise ValueError(f"Unknown interpolation method {self.interpolation}") - - -class LearnableEmb3D_FPS_Aware(VideoPositionEmb): - def __init__( - self, - *, # enforce keyword arguments - model_channels: int, - len_h: int, - len_w: int, - len_t: int, - min_fps: int, # 1 for getty video - max_fps: int, # 120 for getty video - interpolation: str = "crop", - is_learnable: bool = True, - **kwargs, # used for compatibility with other positional embeddings; unused in this class - ): - del kwargs - super().__init__() - assert is_learnable is True - self.interpolation = interpolation - self.max_fps = max_fps - self.min_fps = min_fps - - if self.interpolation == "crop": - self.pos_embed = nn.Parameter( - torch.zeros(1, len_t * int(max_fps / min_fps), len_h, len_w, model_channels) - ) # should be max_seq_length * (max_fps / min_fps) - elif self.interpolation == "resize": - self.pos_embed = nn.Parameter( - torch.zeros(1, len_t, len_h, len_w, model_channels) - ) # time embedding based min fps - else: - ValueError(f"Unknown interpolation method {self.interpolation}") - - trunc_normal_(self.pos_embed, std=0.02) - - def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: - B, T, H, W, C = B_T_H_W_C - - if self.interpolation == "crop": - if T > 1: - return torch.cat( - [ - self.pos_embed[:, : (int(self.max_fps / curr_fps) * T) : int(self.max_fps / 
curr_fps), :H, :W] - for curr_fps in fps - ], - 0, - ) - else: - return self.pos_embed[:, :T, :H, :W] # image model - elif self.interpolation == "resize": - if T > 1: - return torch.cat( - [ - rearrange( - F.interpolate( - rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), - size=(H, W, T * int(curr_fps / self.min_fps)), - mode="trilinear", - align_corners=True, # important: align corner need to be true - )[:, :, :H, :W, :T], - "1 c h w t -> 1 t h w c", - ) - for curr_fps in fps - ], - 0, - ) - else: - # grab self.pos_embed at time step 0 and resize spatially - return rearrange( - F.interpolate( - rearrange(self.pos_embed[:, 0, ::], "1 h w c -> 1 c h w"), - size=(H, W), - mode="bilinear", - align_corners=True, - ), - "1 c h w -> 1 h w c", - ) - raise ValueError(f"Unknown interpolation method {self.interpolation}") - - -class SinCosPosEmbAxis(VideoPositionEmb): - def __init__( - self, - *, # enforce keyword arguments - interpolation: str, - model_channels: int, - len_h: int, - len_w: int, - len_t: int, - h_extrapolation_ratio: float = 1.0, - w_extrapolation_ratio: float = 1.0, - t_extrapolation_ratio: float = 1.0, - **kwargs, - ): - """ - Args: - interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. - """ - del kwargs # unused - super().__init__() - self.interpolation = interpolation - assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" - - dim = model_channels - dim_h = dim // 6 * 2 - dim_w = dim_h - dim_t = dim - 2 * dim_h - assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" - - # rescale pos id is equivalent to rescale frequency - emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(len_h) * 1.0 / h_extrapolation_ratio) - emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(len_w) * 1.0 / w_extrapolation_ratio) - emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(len_t) * 1.0 / t_extrapolation_ratio) - - self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).float(), persistent=False) - self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).float(), persistent=False) - self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).float(), persistent=False) - - def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: - B, T, H, W, C = B_T_H_W_C - if self.interpolation == "crop": - emb_h_H = self.pos_emb_h[:H] - emb_w_W = self.pos_emb_w[:W] - emb_t_T = self.pos_emb_t[:T] - emb = torch.cat( - [ - repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W), - repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W), - repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H), - ], - dim=-1, - ) - assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" - return emb - - raise ValueError(f"Unknown interpolation method {self.interpolation}") - - -class SinCosPosEmb_FPS_Aware(VideoPositionEmb): - def __init__( - self, - *, # enforce keyword arguments - model_channels: int, - len_h: int, - len_w: int, - len_t: int, - min_fps: int, # 1 for getty video - max_fps: int, # 120 for getty video - is_learnable: bool = False, - interpolation: str = "crop", - spatial_interpolation_scale=1.0, - temporal_interpolation_scale=1.0, - **kwargs, # used for compatibility with other positional embeddings; unused in this class - ): - del kwargs # unused - super().__init__() - self.interpolation = interpolation - 
self.max_fps = max_fps - self.min_fps = min_fps - if self.interpolation == "crop": - param = get_3d_sincos_pos_embed( - model_channels, - len_h, - len_w, - len_t * int(max_fps / min_fps), - spatial_interpolation_scale, - temporal_interpolation_scale, - ) # should be max_seq_length * (max_fps / min_fps) - elif self.interpolation == "resize": - param = get_3d_sincos_pos_embed( - model_channels, len_h, len_w, len_t, spatial_interpolation_scale, temporal_interpolation_scale - ) # time embedding based min fps - else: - ValueError(f"Unknown interpolation method {self.interpolation}") - param = rearrange(param, "(h w t) c -> 1 t h w c", h=len_h, w=len_w) - if is_learnable: - self.pos_embed = nn.Parameter( - torch.from_numpy(param).float(), - ) - else: - self.register_buffer("pos_embed", torch.from_numpy(param).float(), persistent=False) - - def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: - B, T, H, W, C = B_T_H_W_C - - if self.interpolation == "crop": - if T > 1: - return torch.cat( - [ - self.pos_embed[:, : (int(self.max_fps / curr_fps) * T) : int(self.max_fps / curr_fps), :H, :W] - for curr_fps in fps - ], - 0, - ) - else: - return self.pos_embed[:, :T, :H, :W] # image model - elif self.interpolation == "resize": - if T > 1: - return torch.cat( - [ - rearrange( - F.interpolate( - rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), - size=(H, W, T * int(curr_fps / self.min_fps)), - mode="trilinear", - align_corners=True, # important: align corner need to be true - )[:, :, :H, :W, :T], - "1 c h w t -> 1 t h w c", - ) - for curr_fps in fps - ], - 0, - ) - else: - # grab self.pos_embed at time step 0 and resize spatially - return rearrange( - F.interpolate( - rearrange(self.pos_embed[:, 0, ::], "1 h w c -> 1 c h w"), - size=(H, W), - mode="bilinear", - align_corners=True, - ), - "1 c h w -> 1 h w c", - ) - raise ValueError(f"Unknown interpolation method {self.interpolation}") - - -class SinCosPosEmb(VideoPositionEmb): - def __init__( - self, - *, # enforce keyword arguments - model_channels: int, - len_h: int, - len_w: int, - len_t: int, - is_learnable: bool = False, - interpolation: Literal["crop", "resize", "crop_resize"] = "crop", - spatial_interpolation_scale=1.0, - temporal_interpolation_scale=1.0, - init_length_for_resize: int = 16, - **kwargs, - ): - """ - Args: - interpolation (str): "crop", "resize", "crop_resize". "crop" means we crop the positional embedding to the length of the input sequence. "resize" means we resize the positional embedding to the length of the input sequence. "crop_resize" (inference only) means we first crop the positional embedding to init_length_for_resize, then resize it to the length of the input sequence. - init_length_for_resize (int): used when interpolation is "crop_resize", where we "resize" embedding during inference for model trained with "crop". 
We first "crop" the pos_embed to this length (used during training), then run the "resize", default 16 - """ - del kwargs # unused - super().__init__() - self.interpolation = interpolation - self.init_length_for_resize = init_length_for_resize - param = get_3d_sincos_pos_embed( - model_channels, len_h, len_w, len_t, spatial_interpolation_scale, temporal_interpolation_scale - ) - param = rearrange(param, "(h w t) c -> 1 t h w c", h=len_h, w=len_w) - if is_learnable: - self.pos_embed = nn.Parameter( - torch.from_numpy(param).float(), - ) - else: - self.register_buffer("pos_embed", torch.from_numpy(param).float(), persistent=False) - - def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: - B, T, H, W, C = B_T_H_W_C - if self.interpolation == "crop": - return self.pos_embed[:, :T, :H, :W] - if self.interpolation == "resize": - return rearrange( - F.interpolate( - rearrange(self.pos_embed, "1 t h w c -> 1 c h w t"), - size=(H, W, T), - mode="linear", - align_corners=False, - ), - "1 c h w t -> 1 t h w c", - ) - if self.interpolation == "crop_resize": - pos_embed_crop = self.pos_embed[:, : self.init_length_for_resize, :H, :W] # B,T,H,W,C - _, t, h, w, c = pos_embed_crop.shape - - pos_embed_crop_resize_t = rearrange( - F.interpolate( - rearrange(pos_embed_crop, "1 t h w c -> 1 (c h w) t"), - size=(T), - mode="linear", - ), - "1 (c h w) t -> 1 t h w c", - c=c, - h=h, - w=w, - ) - pos_embed_crop_resize = rearrange( - F.interpolate( - rearrange(pos_embed_crop_resize_t, "1 t h w c -> 1 (c t) h w"), - size=(H, W), - mode="bilinear", - ), - "1 (c t) h w -> 1 t h w c", - c=c, - ) - return pos_embed_crop_resize - - raise ValueError(f"Unknown interpolation method {self.interpolation}") - - -class MultiCameraVideoPositionEmb(nn.Module): - def __init__( - self, - ): - super().__init__() - self.cp_group = None - - def enable_context_parallel(self, cp_group: ProcessGroup): - self.cp_group = cp_group - - def disable_context_parallel(self): - self.cp_group = None - - def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor]) -> torch.Tensor: - """ - With CP, the function assume that the input tensor is already split. It delegates the embedding generation to generate_embeddings function. 
- """ - B_T_H_W_C = x_B_T_H_W_C.shape - if self.cp_group is not None: - cp_ranks = get_process_group_ranks(self.cp_group) - cp_size = len(cp_ranks) - B, T, H, W, C = B_T_H_W_C - B_T_H_W_C = (B, T * cp_size, H, W, C) - embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps) - - if self.cp_group is not None: - if isinstance(self, MultiCameraVideoRopePosition3DEmb): - seq_dim = 1 - embeddings = rearrange(embeddings, "(V T) H W D -> V (T H W) 1 1 D", V=self.n_views).float() - # rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() - embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) - embeddings = rearrange(embeddings, "V T 1 1 D -> (V T) 1 1 D", V=self.n_views).float() - else: - seq_dim = 1 - embeddings = rearrange(embeddings, "B (V T) H W C -> (B V) T H W C", V=self.n_views) - embeddings = split_inputs_cp(x=embeddings, seq_dim=seq_dim, cp_group=self.cp_group) - embeddings = rearrange(embeddings, "(B V) T H W C -> B (V T) H W C", V=self.n_views) - else: - if isinstance(self, MultiCameraVideoRopePosition3DEmb): - embeddings = rearrange(embeddings, "t h w d -> (t h w) 1 1 d").float() - - return embeddings - - def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]): - raise NotImplementedError - - -class MultiCameraVideoRopePosition3DEmb(MultiCameraVideoPositionEmb): - def __init__( - self, - *, # enforce keyword arguments - head_dim: int, - len_h: int, - len_w: int, - len_t: int, - base_fps: int = 24, - h_extrapolation_ratio: float = 1.0, - w_extrapolation_ratio: float = 1.0, - t_extrapolation_ratio: float = 1.0, - n_views: int = 4, - **kwargs, # used for compatibility with other positional embeddings; unused in this class - ): - del kwargs - super().__init__() - self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float)) - self.base_fps = base_fps - self.max_h = len_h - self.max_w = len_w - self.n_views = n_views - dim = head_dim - dim_h = dim // 6 * 2 - dim_w = dim_h - dim_t = dim - 2 * dim_h - assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" - self.register_buffer( - "dim_spatial_range", - torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h, - persistent=False, - ) - self.register_buffer( - "dim_temporal_range", - torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t, - persistent=False, - ) - - self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) - self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) - self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) - - def generate_embedding_for_batch( - self, - B_T_H_W_C: torch.Size, - fps: Optional[torch.Tensor] = None, - h_ntk_factor: Optional[float] = None, - w_ntk_factor: Optional[float] = None, - t_ntk_factor: Optional[float] = None, - ): - """ - Generate embeddings for the given input size. - - Args: - B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). - fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. - h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. Defaults to None. - w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. Defaults to None. - t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. Defaults to None. - - Returns: - Not specified in the original code snippet. 
- """ - h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor - w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor - t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor - - h_theta = 10000.0 * h_ntk_factor - w_theta = 10000.0 * w_ntk_factor - t_theta = 10000.0 * t_ntk_factor - - h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) - w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) - temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) - - B, T, H, W, _ = B_T_H_W_C - uniform_fps = (fps is None) or (fps.min() == fps.max()) - assert uniform_fps # only support uniform fps now - - assert ( - uniform_fps or B == 1 or T == 1 - ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" - assert ( - H <= self.max_h and W <= self.max_w - ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w}) configured for positional embedding. Please adjust the input size or increase the maximum dimensions in the model configuration." - half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs) - half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs) - - # apply sequence scaling in temporal dimension - if fps is None: # image case - assert T == 1, "T should be 1 for image batch." - half_emb_t = torch.outer(self.seq[:T], temporal_freqs) - else: - half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs) - - em_T_H_W_D = torch.cat( - [ - repeat(half_emb_t, "t d -> t h w d", h=H, w=W), - repeat(half_emb_h, "h d -> t h w d", t=T, w=W), - repeat(half_emb_w, "w d -> t h w d", t=T, h=H), - ] - * 2, - dim=-1, - ) - - return em_T_H_W_D - - def generate_embeddings( - self, - B_T_H_W_C: torch.Size, - fps: Optional[torch.Tensor] = None, - h_ntk_factor: Optional[float] = None, - w_ntk_factor: Optional[float] = None, - t_ntk_factor: Optional[float] = None, - ): - """ - Generate embeddings for the given input size. The camera view dimension is merged in the T dimension - - Args: - B_T_H_W_C (torch.Size): Input tensor size (Batch, Time * Views, Height, Width, Channels). - fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. - h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. Defaults to None. - w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. Defaults to None. - t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. Defaults to None. - - Returns: - Not specified in the original code snippet. 
- """ - - B, T, H, W, C = B_T_H_W_C - - single_camera_B_T_H_W_C = (B, T // self.n_views, H, W, C) - em_T_H_W_D = torch.cat( - [ - self.generate_embedding_for_batch( - single_camera_B_T_H_W_C, - fps=fps, - h_ntk_factor=h_ntk_factor, - w_ntk_factor=w_ntk_factor, - t_ntk_factor=t_ntk_factor, - ) - for item in range(self.n_views) - ], - dim=0, - ) - - return em_T_H_W_D - # return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() - - -class MultiCameraSinCosPosEmbAxis(MultiCameraVideoPositionEmb): - def __init__( - self, - *, # enforce keyword arguments - interpolation: str, - model_channels: int, - len_h: int, - len_w: int, - len_t: int, - h_extrapolation_ratio: float = 1.0, - w_extrapolation_ratio: float = 1.0, - t_extrapolation_ratio: float = 1.0, - n_views: int = 4, - **kwargs, - ): - # TODO: (qsh 2024-11-08) add more interpolation methods and args for extrapolation fine-tuning - """ - Args: - interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. - """ - del kwargs # unused - self.n_views = n_views - super().__init__() - self.interpolation = interpolation - assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" - - dim = model_channels - dim_h = dim // 6 * 2 - dim_w = dim_h - dim_t = dim - 2 * dim_h - assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" - - # rescale pos id is equivalent to rescale frequency - emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(len_h) * 1.0 / h_extrapolation_ratio) - emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(len_w) * 1.0 / w_extrapolation_ratio) - emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(len_t) * 1.0 / t_extrapolation_ratio) - - self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).float(), persistent=False) - self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).float(), persistent=False) - self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).float(), persistent=False) - - def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor]) -> torch.Tensor: - B, T, H, W, C = B_T_H_W_C - - single_camera_T = T // self.n_views - - if self.interpolation == "crop": - emb_h_H = self.pos_emb_h[:H] - emb_w_W = self.pos_emb_w[:W] - emb_t_T = self.pos_emb_t[:single_camera_T] - emb = torch.cat( - [ - torch.cat( - [ - repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W), - repeat(emb_h_H, "h d-> b t h w d", b=B, t=single_camera_T, w=W), - repeat(emb_w_W, "w d-> b t h w d", b=B, t=single_camera_T, h=H), - ], - dim=-1, - ) - for _ in range(self.n_views) - ], - 1, - ) - assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" - return emb - - raise ValueError(f"Unknown interpolation method {self.interpolation}") diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/module/pretrained_vae.py b/cosmos-transfer1/cosmos_transfer1/diffusion/module/pretrained_vae.py deleted file mode 100644 index 5698284b6c9922bad96d82337062e08b1ad61de2..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/module/pretrained_vae.py +++ /dev/null @@ -1,605 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from abc import ABC, abstractmethod - -import torch -from einops import rearrange -from torch.nn.modules import Module - - -class BaseVAE(torch.nn.Module, ABC): - """ - Abstract base class for a Variational Autoencoder (VAE). - - All subclasses should implement the methods to define the behavior for encoding - and decoding, along with specifying the latent channel size. - """ - - def __init__(self, channel: int = 3, name: str = "vae"): - super().__init__() - self.channel = channel - self.name = name - - @property - def latent_ch(self) -> int: - """ - Returns the number of latent channels in the VAE. - """ - return self.channel - - @abstractmethod - def encode(self, state: torch.Tensor) -> torch.Tensor: - """ - Encodes the input tensor into a latent representation. - - Args: - - state (torch.Tensor): The input tensor to encode. - - Returns: - - torch.Tensor: The encoded latent tensor. - """ - pass - - @abstractmethod - def decode(self, latent: torch.Tensor) -> torch.Tensor: - """ - Decodes the latent representation back to the original space. - - Args: - - latent (torch.Tensor): The latent tensor to decode. - - Returns: - - torch.Tensor: The decoded tensor. - """ - pass - - @property - def spatial_compression_factor(self) -> int: - """ - Returns the spatial reduction factor for the VAE. - """ - raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.") - - -class BasePretrainedImageVAE(BaseVAE): - """ - A base class for pretrained Variational Autoencoder (VAE) that loads mean and standard deviation values - from a remote store, handles data type conversions, and normalization - using provided mean and standard deviation values for latent space representation. - Derived classes should load pre-trained encoder and decoder components from a remote store - - Attributes: - latent_mean (Tensor): The mean used for normalizing the latent representation. - latent_std (Tensor): The standard deviation used for normalizing the latent representation. - dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. - - Args: - mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. - latent_ch (int, optional): Number of latent channels (default is 16). - is_image (bool, optional): Flag to indicate whether the output is an image (default is True). - is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). 
- """ - - def __init__( - self, - name: str, - latent_ch: int = 16, - is_image: bool = True, - is_bf16: bool = True, - ) -> None: - super().__init__(latent_ch, name) - dtype = torch.bfloat16 if is_bf16 else torch.float32 - self.dtype = dtype - self.is_image = is_image - self.name = name - - def register_mean_std(self, vae_dir: str) -> None: - latent_mean, latent_std = torch.load(os.path.join(vae_dir, "image_mean_std.pt"), weights_only=False) - - target_shape = [1, self.latent_ch, 1, 1] if self.is_image else [1, self.latent_ch, 1, 1, 1] - - self.register_buffer( - "latent_mean", - latent_mean.to(self.dtype).reshape(*target_shape), - persistent=False, - ) - self.register_buffer( - "latent_std", - latent_std.to(self.dtype).reshape(*target_shape), - persistent=False, - ) - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - """ - Encode the input state to latent space; also handle the dtype conversion, mean and std scaling - """ - in_dtype = state.dtype - latent_mean = self.latent_mean.to(in_dtype) - latent_std = self.latent_std.to(in_dtype) - encoded_state = self.encoder(state.to(self.dtype)) - if isinstance(encoded_state, torch.Tensor): - pass - elif isinstance(encoded_state, tuple): - assert isinstance(encoded_state[0], torch.Tensor) - encoded_state = encoded_state[0] - else: - raise ValueError("Invalid type of encoded state") - return (encoded_state.to(in_dtype) - latent_mean) / latent_std - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - """ - Decode the input latent to state; also handle the dtype conversion, mean and std scaling - """ - in_dtype = latent.dtype - latent = latent * self.latent_std.to(in_dtype) + self.latent_mean.to(in_dtype) - return self.decoder(latent.to(self.dtype)).to(in_dtype) - - def reset_dtype(self, *args, **kwargs): - """ - Resets the data type of the encoder and decoder to the model's default data type. - - Args: - *args, **kwargs: Unused, present to allow flexibility in method calls. - """ - del args, kwargs - self.decoder.to(self.dtype) - self.encoder.to(self.dtype) - - -class JITVAE(BasePretrainedImageVAE): - """ - A JIT compiled Variational Autoencoder (VAE) that loads pre-trained encoder - and decoder components from a remote store, handles data type conversions, and normalization - using provided mean and standard deviation values for latent space representation. - - Attributes: - encoder (Module): The JIT compiled encoder loaded from storage. - decoder (Module): The JIT compiled decoder loaded from storage. - latent_mean (Tensor): The mean used for normalizing the latent representation. - latent_std (Tensor): The standard deviation used for normalizing the latent representation. - dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. - - Args: - name (str): Name of the model, used for differentiating cache file paths. - latent_ch (int, optional): Number of latent channels (default is 16). - is_image (bool, optional): Flag to indicate whether the output is an image (default is True). - is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). - """ - - def __init__( - self, - name: str, - latent_ch: int = 16, - is_image: bool = True, - is_bf16: bool = True, - ): - super().__init__(name, latent_ch, is_image, is_bf16) - - def load_encoder(self, vae_dir: str) -> None: - """ - Load the encoder from the remote store. 
- """ - self.encoder = torch.load(os.path.join(vae_dir, "encoder.jit"), weights_only=False) - - self.encoder.eval() - for param in self.encoder.parameters(): - param.requires_grad = False - self.encoder.to(self.dtype) - - def load_decoder(self, vae_dir: str) -> None: - """ - Load the decoder from the remote store. - """ - self.decoder = torch.load(os.path.join(vae_dir, "decoder.jit"), weights_only=False) - - self.decoder.eval() - for param in self.decoder.parameters(): - param.requires_grad = False - self.decoder.to(self.dtype) - - -class BaseVAE(torch.nn.Module, ABC): - """ - Abstract base class for a Variational Autoencoder (VAE). - - All subclasses should implement the methods to define the behavior for encoding - and decoding, along with specifying the latent channel size. - """ - - def __init__(self, channel: int = 3, name: str = "vae"): - super().__init__() - self.channel = channel - self.name = name - - @property - def latent_ch(self) -> int: - """ - Returns the number of latent channels in the VAE. - """ - return self.channel - - @abstractmethod - def encode(self, state: torch.Tensor) -> torch.Tensor: - """ - Encodes the input tensor into a latent representation. - - Args: - - state (torch.Tensor): The input tensor to encode. - - Returns: - - torch.Tensor: The encoded latent tensor. - """ - pass - - @abstractmethod - def decode(self, latent: torch.Tensor) -> torch.Tensor: - """ - Decodes the latent representation back to the original space. - - Args: - - latent (torch.Tensor): The latent tensor to decode. - - Returns: - - torch.Tensor: The decoded tensor. - """ - pass - - @property - def spatial_compression_factor(self) -> int: - """ - Returns the spatial reduction factor for the VAE. - """ - raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.") - - -class VideoTokenizerInterface(ABC): - @abstractmethod - def encode(self, state: torch.Tensor) -> torch.Tensor: - pass - - @abstractmethod - def decode(self, latent: torch.Tensor) -> torch.Tensor: - pass - - @abstractmethod - def get_latent_num_frames(self, num_pixel_frames: int) -> int: - pass - - @abstractmethod - def get_pixel_num_frames(self, num_latent_frames: int) -> int: - pass - - @property - @abstractmethod - def spatial_compression_factor(self): - pass - - @property - @abstractmethod - def temporal_compression_factor(self): - pass - - @property - @abstractmethod - def spatial_resolution(self): - pass - - @property - @abstractmethod - def pixel_chunk_duration(self): - pass - - @property - @abstractmethod - def latent_chunk_duration(self): - pass - - -class BasePretrainedVideoTokenizer(ABC): - """ - Base class for a pretrained video tokenizer that handles chunking of video data for efficient processing. - - Args: - pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. - temporal_compress_factor (int): The factor by which the video data is temporally compressed during processing. - max_enc_batch_size (int): The maximum batch size to process in one go during encoding to avoid memory overflow. - max_dec_batch_size (int): The maximum batch size to process in one go during decoding to avoid memory overflow. - - The class introduces parameters for managing temporal chunks (`pixel_chunk_duration` and `temporal_compress_factor`) - which define how video data is subdivided and compressed during the encoding and decoding processes. 
The - `max_enc_batch_size` and `max_dec_batch_size` parameters allow processing in smaller batches to handle memory - constraints. - """ - - def __init__( - self, - pixel_chunk_duration: int = 17, - temporal_compress_factor: int = 8, - max_enc_batch_size: int = 8, - max_dec_batch_size: int = 4, - ): - self._pixel_chunk_duration = pixel_chunk_duration - self._temporal_compress_factor = temporal_compress_factor - self.max_enc_batch_size = max_enc_batch_size - self.max_dec_batch_size = max_dec_batch_size - - def register_mean_std(self, vae_dir: str) -> None: - latent_mean, latent_std = torch.load(os.path.join(vae_dir, "mean_std.pt"), weights_only=False) - - latent_mean = latent_mean.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] - latent_std = latent_std.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] - - target_shape = [1, self.latent_ch, self.latent_chunk_duration, 1, 1] - - self.register_buffer( - "latent_mean", - latent_mean.to(self.dtype).reshape(*target_shape), - persistent=False, - ) - self.register_buffer( - "latent_std", - latent_std.to(self.dtype).reshape(*target_shape), - persistent=False, - ) - - def transform_encode_state_shape(self, state: torch.Tensor) -> torch.Tensor: - """ - Rearranges the input state tensor to the required shape for encoding video data. Mainly for chunk based encoding - """ - B, C, T, H, W = state.shape - assert ( - T % self.pixel_chunk_duration == 0 - ), f"Temporal dimension {T} is not divisible by chunk_length {self.pixel_chunk_duration}" - return rearrange(state, "b c (n t) h w -> (b n) c t h w", t=self.pixel_chunk_duration) - - def transform_decode_state_shape(self, latent: torch.Tensor) -> torch.Tensor: - B, _, T, _, _ = latent.shape - assert ( - T % self.latent_chunk_duration == 0 - ), f"Temporal dimension {T} is not divisible by chunk_length {self.latent_chunk_duration}" - return rearrange(latent, "b c (n t) h w -> (b n) c t h w", t=self.latent_chunk_duration) - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - if self._temporal_compress_factor == 1: - _, _, origin_T, _, _ = state.shape - state = rearrange(state, "b c t h w -> (b t) c 1 h w") - B, C, T, H, W = state.shape - state = self.transform_encode_state_shape(state) - # use max_enc_batch_size to avoid OOM - if state.shape[0] > self.max_enc_batch_size: - latent = [] - for i in range(0, state.shape[0], self.max_enc_batch_size): - latent.append(super().encode(state[i : i + self.max_enc_batch_size])) - latent = torch.cat(latent, dim=0) - else: - latent = super().encode(state) - - latent = rearrange(latent, "(b n) c t h w -> b c (n t) h w", b=B) - if self._temporal_compress_factor == 1: - latent = rearrange(latent, "(b t) c 1 h w -> b c t h w", t=origin_T) - return latent - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - """ - Decodes a batch of latent representations into video frames by applying temporal chunking. Similar to encode, - it handles video data by processing smaller temporal chunks to reconstruct the original video dimensions. - - It can also decode single frame image data. - - Args: - latent (torch.Tensor): The latent space tensor containing encoded video data. - - Returns: - torch.Tensor: The decoded video tensor reconstructed from latent space. 
- """ - if self._temporal_compress_factor == 1: - _, _, origin_T, _, _ = latent.shape - latent = rearrange(latent, "b c t h w -> (b t) c 1 h w") - B, _, T, _, _ = latent.shape - latent = self.transform_decode_state_shape(latent) - # use max_enc_batch_size to avoid OOM - if latent.shape[0] > self.max_dec_batch_size: - state = [] - for i in range(0, latent.shape[0], self.max_dec_batch_size): - state.append(super().decode(latent[i : i + self.max_dec_batch_size])) - state = torch.cat(state, dim=0) - else: - state = super().decode(latent) - assert state.shape[2] == self.pixel_chunk_duration - state = rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B) - if self._temporal_compress_factor == 1: - return rearrange(state, "(b t) c 1 h w -> b c t h w", t=origin_T) - return state - - @property - def pixel_chunk_duration(self) -> int: - return self._pixel_chunk_duration - - @property - def latent_chunk_duration(self) -> int: - # return self._latent_chunk_duration - assert (self.pixel_chunk_duration - 1) % self.temporal_compression_factor == 0, ( - f"Pixel chunk duration {self.pixel_chunk_duration} is not divisible by latent chunk duration " - f"{self.latent_chunk_duration}" - ) - return (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1 - - @property - def temporal_compression_factor(self): - return self._temporal_compress_factor - - def get_latent_num_frames(self, num_pixel_frames: int) -> int: - if num_pixel_frames == 1: - return 1 - assert ( - num_pixel_frames % self.pixel_chunk_duration == 0 - ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.pixel_chunk_duration}" - return num_pixel_frames // self.pixel_chunk_duration * self.latent_chunk_duration - - def get_pixel_num_frames(self, num_latent_frames: int) -> int: - if num_latent_frames == 1: - return 1 - assert ( - num_latent_frames % self.latent_chunk_duration == 0 - ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_chunk_duration}" - return num_latent_frames // self.latent_chunk_duration * self.pixel_chunk_duration - - -class VideoJITTokenizer(BasePretrainedVideoTokenizer, JITVAE, VideoTokenizerInterface): - """ - Instance of BasePretrainedVideoVAE that loads encoder and decoder from JIT scripted module file - """ - - def __init__( - self, - name: str, - latent_ch: int = 16, - is_bf16: bool = True, - spatial_compression_factor: int = 16, - temporal_compression_factor: int = 8, - pixel_chunk_duration: int = 17, - max_enc_batch_size: int = 8, - max_dec_batch_size: int = 4, - spatial_resolution: str = "720", - ): - super().__init__( - pixel_chunk_duration, - temporal_compression_factor, - max_enc_batch_size, - max_dec_batch_size, - ) - super(BasePretrainedVideoTokenizer, self).__init__( - name, - latent_ch, - False, - is_bf16, - ) - - self._spatial_compression_factor = spatial_compression_factor - self._spatial_resolution = spatial_resolution - - @property - def spatial_compression_factor(self): - return self._spatial_compression_factor - - @property - def spatial_resolution(self) -> str: - return self._spatial_resolution - - -class JointImageVideoTokenizer(BaseVAE, VideoTokenizerInterface): - def __init__( - self, - image_vae: torch.nn.Module, - video_vae: torch.nn.Module, - name: str, - latent_ch: int = 16, - squeeze_for_image: bool = True, - ): - super().__init__(latent_ch, name) - self.image_vae = image_vae - self.video_vae = video_vae - self.squeeze_for_image = squeeze_for_image - - def encode_image(self, state: torch.Tensor) -> torch.Tensor: - if 
self.squeeze_for_image: - return self.image_vae.encode(state.squeeze(2)).unsqueeze(2) - return self.image_vae.encode(state) - - def decode_image(self, latent: torch.Tensor) -> torch.Tensor: - if self.squeeze_for_image: - return self.image_vae.decode(latent.squeeze(2)).unsqueeze(2) - return self.image_vae.decode(latent) - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - B, C, T, H, W = state.shape - if T == 1: - return self.encode_image(state) - - return self.video_vae.encode(state) - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - B, C, T, H, W = latent.shape - if T == 1: - return self.decode_image(latent) - return self.video_vae.decode(latent) - - def reset_dtype(self, *args, **kwargs): - """ - Resets the data type of the encoder and decoder to the model's default data type. - - Args: - *args, **kwargs: Unused, present to allow flexibility in method calls. - """ - del args, kwargs - self.video_vae.reset_dtype() - - def get_latent_num_frames(self, num_pixel_frames: int) -> int: - if num_pixel_frames == 1: - return 1 - return self.video_vae.get_latent_num_frames(num_pixel_frames) - - def get_pixel_num_frames(self, num_latent_frames: int) -> int: - if num_latent_frames == 1: - return 1 - return self.video_vae.get_pixel_num_frames(num_latent_frames) - - @property - def spatial_compression_factor(self): - return self.video_vae.spatial_compression_factor - - @property - def temporal_compression_factor(self): - return self.video_vae.temporal_compression_factor - - @property - def spatial_resolution(self) -> str: - return self.video_vae.spatial_resolution - - @property - def pixel_chunk_duration(self) -> int: - return self.video_vae.pixel_chunk_duration - - @property - def latent_chunk_duration(self) -> int: - return self.video_vae.latent_chunk_duration - - -class JointImageVideoSharedJITTokenizer(JointImageVideoTokenizer): - """ - First version of the ImageVideoVAE trained with Fitsum. - We have to use seperate mean and std for image and video due to non-causal nature of the model. - """ - - def __init__(self, image_vae: Module, video_vae: Module, name: str, latent_ch: int = 16): - super().__init__(image_vae, video_vae, name, latent_ch, squeeze_for_image=False) - assert isinstance(image_vae, JITVAE) - assert isinstance( - video_vae, VideoJITTokenizer - ), f"video_vae should be an instance of VideoJITVAE, got {type(video_vae)}" - - def load_weights(self, vae_dir: str): - self.video_vae.register_mean_std(vae_dir) - - self.video_vae.load_decoder(vae_dir) - self.video_vae.load_encoder(vae_dir) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/module/timm.py b/cosmos-transfer1/cosmos_transfer1/diffusion/module/timm.py deleted file mode 100644 index ebe6dfe88a4afead7de85133fb74bbee181d7f49..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/module/timm.py +++ /dev/null @@ -1,78 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import warnings - -import torch - - -def _no_grad_trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - with torch.no_grad(): - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - tensor.clamp_(min=a, max=b) - return tensor - - -def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): - # type: (Tensor, float, float, float, float) -> Tensor - r"""Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \leq \text{mean} \leq b`. - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - Examples: - >>> w = torch.empty(3, 5) - >>> nn.init.trunc_normal_(w) - """ - return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/networks/__init__.py b/cosmos-transfer1/cosmos_transfer1/diffusion/networks/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/networks/distill_controlnet_wrapper.py b/cosmos-transfer1/cosmos_transfer1/diffusion/networks/distill_controlnet_wrapper.py deleted file mode 100644 index edb647bf5e7775c3c8647d57bd7df9c8fe0d79d7..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/networks/distill_controlnet_wrapper.py +++ /dev/null @@ -1,84 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from torch.distributed import ProcessGroup - -from cosmos_transfer1.diffusion.module.blocks import DITBuildingBlock -from cosmos_transfer1.utils.lazy_config import instantiate as lazy_instantiate - - -class DistillControlNet(torch.nn.Module): - r"""Wrapper class for the control net. - - This class wraps the control net (self.net_ctrl) and the base model (self.base_model) into a single class for distillation purpose. - In distillation, both the control net and the base model are getting updated. - For example, in DMD2, the student and the fake score are instantiated from this class. - - This class also accommodates the forward method of the control net, which requires the base model as an argument and - call the base_model.net. - - Args: - config (Config): Configuration - - """ - - def __init__(self, config): - super().__init__() - - self.cp_group = None - self.net_ctrl = lazy_instantiate(config.net_ctrl) - - class BaseModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.net = lazy_instantiate(config.net) - - self.base_model = BaseModel() - - def forward(self, *args, **kwargs): - # The following line is necessary for their original forward method in class GeneralDITEncoder to work properly. - kwargs["base_model"] = self.base_model - return self.net_ctrl(*args, **kwargs) - - @property - def fsdp_wrap_block_cls(self): - return DITBuildingBlock - - def enable_context_parallel(self, cp_group: ProcessGroup): - self.base_model.net.enable_context_parallel(cp_group) - self.net_ctrl.enable_context_parallel(cp_group) - self.cp_group = cp_group - - def disable_context_parallel(self): - self.base_model.net.disable_context_parallel() - self.net_ctrl.disable_context_parallel() - self.cp_group = None - - def enable_sequence_parallel(self): - self.base_model.net.enable_sequence_parallel() - self.net_ctrl.enable_sequence_parallel() - - def disable_sequence_parallel(self): - self.base_model.net.disable_sequence_parallel() - self.net_ctrl.disable_sequence_parallel() - - def _set_sequence_parallel(self, status: bool): - self.base_model.net._set_sequence_parallel(status) - self.net_ctrl._set_sequence_parallel(status) - - @property - def is_context_parallel_enabled(self): - return (self.base_model.net.cp_group is not None) and (self.net_ctrl.cp_group is not None) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit.py b/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit.py deleted file mode 100644 index fc706a443a313d80dd55fdeb1f9bc23561d75009..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit.py +++ /dev/null @@ -1,611 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" -A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. -""" - -from typing import List, Optional, Tuple - -import torch -from einops import rearrange -from torch import nn -from torch.distributed import ProcessGroup, get_process_group_ranks -from torchvision import transforms - -from cosmos_transfer1.diffusion.conditioner import DataType -from cosmos_transfer1.diffusion.module.attention import get_normalization -from cosmos_transfer1.diffusion.module.blocks import ( - FinalLayer, - GeneralDITTransformerBlock, - PatchEmbed, - TimestepEmbedding, - Timesteps, -) -from cosmos_transfer1.diffusion.module.position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb -from cosmos_transfer1.utils import log - - -class GeneralDIT(nn.Module): - """ - A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. - - Args: - max_img_h (int): Maximum height of the input images. - max_img_w (int): Maximum width of the input images. - max_frames (int): Maximum number of frames in the video sequence. - in_channels (int): Number of input channels (e.g., RGB channels for color images). - out_channels (int): Number of output channels. - patch_spatial (tuple): Spatial resolution of patches for input processing. - patch_temporal (int): Temporal resolution of patches for input processing. - concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. - block_config (str): Configuration of the transformer block. See Notes for supported block types. - model_channels (int): Base number of channels used throughout the model. - num_blocks (int): Number of transformer blocks. - num_heads (int): Number of heads in the multi-head attention layers. - mlp_ratio (float): Expansion ratio for MLP blocks. - block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD'). - crossattn_emb_channels (int): Number of embedding channels for cross-attention. - use_cross_attn_mask (bool): Whether to use mask in cross-attention. - pos_emb_cls (str): Type of positional embeddings. - pos_emb_learnable (bool): Whether positional embeddings are learnable. - pos_emb_interpolation (str): Method for interpolating positional embeddings. - affline_emb_norm (bool): Whether to normalize affine embeddings. - use_adaln_lora (bool): Whether to use AdaLN-LoRA. - adaln_lora_dim (int): Dimension for AdaLN-LoRA. - rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. - rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. - rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. - extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. - extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings. - extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. - extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. - extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. 
- - Notes: - Supported block types in block_config: - * cross_attn, ca: Cross attention - * full_attn: Full attention on all flattened tokens - * mlp, ff: Feed forward block - """ - - def __init__( - self, - max_img_h: int, - max_img_w: int, - max_frames: int, - in_channels: int, - out_channels: int, - patch_spatial: tuple, - patch_temporal: int, - concat_padding_mask: bool = True, - # attention settings - block_config: str = "FA-CA-MLP", - model_channels: int = 768, - num_blocks: int = 10, - num_heads: int = 16, - mlp_ratio: float = 4.0, - block_x_format: str = "BTHWD", - # cross attention settings - crossattn_emb_channels: int = 1024, - use_cross_attn_mask: bool = False, - # positional embedding settings - pos_emb_cls: str = "sincos", - pos_emb_learnable: bool = False, - pos_emb_interpolation: str = "crop", - affline_emb_norm: bool = False, # whether or not to normalize the affine embedding - use_adaln_lora: bool = False, - adaln_lora_dim: int = 256, - rope_h_extrapolation_ratio: float = 1.0, - rope_w_extrapolation_ratio: float = 1.0, - rope_t_extrapolation_ratio: float = 1.0, - extra_per_block_abs_pos_emb: bool = False, - extra_per_block_abs_pos_emb_type: str = "sincos", - extra_h_extrapolation_ratio: float = 1.0, - extra_w_extrapolation_ratio: float = 1.0, - extra_t_extrapolation_ratio: float = 1.0, - layer_mask: list = None, # whether or not a layer is used. For controlnet encoder - ) -> None: - super().__init__() - self.max_img_h = max_img_h - self.max_img_w = max_img_w - self.max_frames = max_frames - self.in_channels = in_channels - self.out_channels = out_channels - self.patch_spatial = patch_spatial - self.patch_temporal = patch_temporal - self.num_heads = num_heads - self.num_blocks = num_blocks - self.model_channels = model_channels - self.use_cross_attn_mask = use_cross_attn_mask - self.concat_padding_mask = concat_padding_mask - self.block_config = block_config - # positional embedding settings - self.pos_emb_cls = pos_emb_cls - self.pos_emb_learnable = pos_emb_learnable - self.pos_emb_interpolation = pos_emb_interpolation - self.affline_emb_norm = affline_emb_norm - self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio - self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio - self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio - self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb - self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower() - self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio - self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio - self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio - - self.build_patch_embed() - self.build_pos_embed() - self.cp_group = None - self.block_x_format = block_x_format - self.use_adaln_lora = use_adaln_lora - self.adaln_lora_dim = adaln_lora_dim - self.t_embedder = nn.Sequential( - Timesteps(model_channels), - TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora), - ) - - self.blocks = nn.ModuleDict() - - layer_mask = [False] * num_blocks if layer_mask is None else layer_mask - assert ( - len(layer_mask) == num_blocks - ), f"Layer mask length {len(layer_mask)} does not match num_blocks {num_blocks}" - for idx in range(num_blocks): - if layer_mask[idx]: - continue - self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( - x_dim=model_channels, - context_dim=crossattn_emb_channels, - num_heads=num_heads, - block_config=block_config, - mlp_ratio=mlp_ratio, - x_format=self.block_x_format, - 
use_adaln_lora=use_adaln_lora, - adaln_lora_dim=adaln_lora_dim, - ) - - self.build_decode_head() - if self.affline_emb_norm: - log.debug("Building affine embedding normalization layer") - self.affline_norm = get_normalization("R", model_channels) - else: - self.affline_norm = nn.Identity() - self.initialize_weights() - - def initialize_weights(self): - # Initialize transformer layers: - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - self.apply(_basic_init) - - # Initialize timestep embedding - nn.init.normal_(self.t_embedder[1].linear_1.weight, std=0.02) - if self.t_embedder[1].linear_1.bias is not None: - nn.init.constant_(self.t_embedder[1].linear_1.bias, 0) - nn.init.normal_(self.t_embedder[1].linear_2.weight, std=0.02) - if self.t_embedder[1].linear_2.bias is not None: - nn.init.constant_(self.t_embedder[1].linear_2.bias, 0) - - # Zero-out adaLN modulation layers in DiT blocks: - for transformer_block in self.blocks.values(): - for block in transformer_block.blocks: - nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - if block.adaLN_modulation[-1].bias is not None: - nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - def build_decode_head(self): - self.final_layer = FinalLayer( - hidden_size=self.model_channels, - spatial_patch_size=self.patch_spatial, - temporal_patch_size=self.patch_temporal, - out_channels=self.out_channels, - use_adaln_lora=self.use_adaln_lora, - adaln_lora_dim=self.adaln_lora_dim, - ) - - def build_patch_embed(self): - ( - concat_padding_mask, - in_channels, - patch_spatial, - patch_temporal, - model_channels, - ) = ( - self.concat_padding_mask, - self.in_channels, - self.patch_spatial, - self.patch_temporal, - self.model_channels, - ) - in_channels = in_channels + 1 if concat_padding_mask else in_channels - self.x_embedder = PatchEmbed( - spatial_patch_size=patch_spatial, - temporal_patch_size=patch_temporal, - in_channels=in_channels, - out_channels=model_channels, - bias=False, - ) - - def build_pos_embed(self): - if self.pos_emb_cls == "rope3d": - cls_type = VideoRopePosition3DEmb - else: - raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") - - log.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") - kwargs = dict( - model_channels=self.model_channels, - len_h=self.max_img_h // self.patch_spatial, - len_w=self.max_img_w // self.patch_spatial, - len_t=self.max_frames // self.patch_temporal, - is_learnable=self.pos_emb_learnable, - interpolation=self.pos_emb_interpolation, - head_dim=self.model_channels // self.num_heads, - h_extrapolation_ratio=self.rope_h_extrapolation_ratio, - w_extrapolation_ratio=self.rope_w_extrapolation_ratio, - t_extrapolation_ratio=self.rope_t_extrapolation_ratio, - ) - self.pos_embedder = cls_type( - **kwargs, - ) - - if self.extra_per_block_abs_pos_emb: - assert self.extra_per_block_abs_pos_emb_type in [ - "learnable", - ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" - kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio - kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio - kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio - self.extra_pos_embedder = LearnablePosEmbAxis( - **kwargs, - ) - - def prepare_embedded_sequence( - self, - x_B_C_T_H_W: torch.Tensor, - fps: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - latent_condition: 
Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. - - Args: - x_B_C_T_H_W (torch.Tensor): video - fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. - If None, a default value (`self.base_fps`) will be used. - padding_mask (Optional[torch.Tensor]): current it is not used - - Returns: - Tuple[torch.Tensor, Optional[torch.Tensor]]: - - A tensor of shape (B, T, H, W, D) with the embedded sequence. - - An optional positional embedding tensor, returned only if the positional embedding class - (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. - - Notes: - - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. - - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. - - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using - the `self.pos_embedder` with the shape [T, H, W]. - - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the - `self.pos_embedder` with the fps tensor. - - Otherwise, the positional embeddings are generated without considering fps. - """ - if self.concat_padding_mask: - padding_mask = transforms.functional.resize( - padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST - ) - padding_mask = padding_mask.unsqueeze(2).expand(x_B_C_T_H_W.size(0), -1, x_B_C_T_H_W.size(2), -1, -1) - x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, padding_mask], dim=1) # [B, C+1, T, H, W] - x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) - - if self.extra_per_block_abs_pos_emb: - extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) - else: - extra_pos_emb = None - - if "rope" in self.pos_emb_cls.lower(): - return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb - - if "fps_aware" in self.pos_emb_cls: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] - else: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] - - return x_B_T_H_W_D, None, extra_pos_emb - - def decoder_head( - self, - x_B_T_H_W_D: torch.Tensor, - emb_B_D: torch.Tensor, - crossattn_emb: torch.Tensor, - origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W] - crossattn_mask: Optional[torch.Tensor] = None, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - del crossattn_emb, crossattn_mask - B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape - x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D") - x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D) - # This is to ensure x_BT_HW_D has the correct shape because - # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). 
- x_BT_HW_D = x_BT_HW_D.view( - B * T_before_patchify // self.patch_temporal, - H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, - -1, - ) - x_B_D_T_H_W = rearrange( - x_BT_HW_D, - "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", - p1=self.patch_spatial, - p2=self.patch_spatial, - H=H_before_patchify // self.patch_spatial, - W=W_before_patchify // self.patch_spatial, - t=self.patch_temporal, - B=B, - ) - return x_B_D_T_H_W - - def forward_before_blocks( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """ - Args: - x: (B, C, T, H, W) tensor of spatial-temp inputs - timesteps: (B, ) tensor of timesteps - crossattn_emb: (B, N, D) tensor of cross-attention embeddings - crossattn_mask: (B, N) tensor of cross-attention masks - """ - del kwargs - assert isinstance( - data_type, DataType - ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." - original_shape = x.shape - x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( - x, - fps=fps, - padding_mask=padding_mask, - latent_condition=latent_condition, - latent_condition_sigma=latent_condition_sigma, - ) - # logging affline scale information - affline_scale_log_info = {} - - timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) - affline_emb_B_D = timesteps_B_D - affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() - - if scalar_feature is not None: - raise NotImplementedError("Scalar feature is not implemented yet.") - - affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() - affline_emb_B_D = self.affline_norm(affline_emb_B_D) - - if self.use_cross_attn_mask: - crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] - else: - crossattn_mask = None - - if self.blocks["block0"].x_format == "THWBD": - x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" - ) - crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") - - if crossattn_mask: - crossattn_mask = rearrange(crossattn_mask, "B M -> M B") - - # For regional contexts - if regional_contexts is not None: - regional_contexts = rearrange(regional_contexts, "B R M D -> R M B D") - - # For region masks (assuming 5D format) - if region_masks is not None: - region_masks = rearrange(region_masks, "B R T H W -> R T H W B") - - elif self.blocks["block0"].x_format == "BTHWD": - x = x_B_T_H_W_D - else: - raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") - output = { - "x": x, - "affline_emb_B_D": affline_emb_B_D, - "crossattn_emb": crossattn_emb, - "crossattn_mask": crossattn_mask, - "rope_emb_L_1_1_D": rope_emb_L_1_1_D, - "adaln_lora_B_3D": adaln_lora_B_3D, - "original_shape": original_shape, - "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - 
"regional_contexts": regional_contexts, - "region_masks": region_masks, - } - return output - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - condition_video_augment_sigma: Optional[torch.Tensor] = None, - x_ctrl: Optional[dict] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: - """ - Args: - x: (B, C, T, H, W) tensor of spatial-temp inputs - timesteps: (B, ) tensor of timesteps - crossattn_emb: (B, N, D) tensor of cross-attention embeddings - crossattn_mask: (B, N) tensor of cross-attention masks - condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to - augment condition input, the lvg model will condition on the condition_video_augment_sigma value; - we need forward_before_blocks pass to the forward_before_blocks function. - regional_contexts: Optional list of regional prompt embeddings, each of shape (B, N, D) - region_masks: Optional tensor of region masks of shape (B, R, THW) - """ - - inputs = self.forward_before_blocks( - x=x, - timesteps=timesteps, - crossattn_emb=crossattn_emb, - crossattn_mask=crossattn_mask, - fps=fps, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - latent_condition=latent_condition, - latent_condition_sigma=latent_condition_sigma, - condition_video_augment_sigma=condition_video_augment_sigma, - regional_contexts=regional_contexts, - region_masks=region_masks, - **kwargs, - ) - ( - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D, - adaln_lora_B_3D, - original_shape, - ) = ( - inputs["x"], - inputs["affline_emb_B_D"], - inputs["crossattn_emb"], - inputs["crossattn_mask"], - inputs["rope_emb_L_1_1_D"], - inputs["adaln_lora_B_3D"], - inputs["original_shape"], - ) - if regional_contexts is not None: - regional_contexts = inputs["regional_contexts"] - if region_masks is not None: - region_masks = inputs["region_masks"] - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"] - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - assert ( - x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape - ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" - - for name, block in self.blocks.items(): - assert ( - self.blocks["block0"].x_format == block.x_format - ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" - - x = block( - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - adaln_lora_B_3D=adaln_lora_B_3D, - extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - regional_contexts=regional_contexts, - region_masks=region_masks, - ) - if x_ctrl is not None and name in x_ctrl: - x = x + x_ctrl[name] - - x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") - - x_B_D_T_H_W = self.decoder_head( - x_B_T_H_W_D=x_B_T_H_W_D, - emb_B_D=affline_emb_B_D, - crossattn_emb=None, - origin_shape=original_shape, - crossattn_mask=None, - adaln_lora_B_3D=adaln_lora_B_3D, - ) - - return x_B_D_T_H_W - 
- def enable_context_parallel(self, cp_group: ProcessGroup): - cp_ranks = get_process_group_ranks(cp_group) - cp_size = len(cp_ranks) - # Set these attributes for spliting the data after embedding. - self.cp_group = cp_group - # Set these attributes for computing the loss. - self.cp_size = cp_size - - self.pos_embedder.enable_context_parallel(cp_group) - if self.extra_per_block_abs_pos_emb: - self.extra_pos_embedder.enable_context_parallel(cp_group) - # Loop through the model to set up context parallel. - for block in self.blocks.values(): - for layer in block.blocks: - if layer.block_type in ["mlp", "ff"]: - continue - elif layer.block_type in ["cross_attn", "ca"]: - continue - else: - layer.block.attn.attn_op.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) - - log.debug(f"[CP] Enable context parallelism with size {cp_size}") - - def disable_context_parallel(self): - self.cp_group = None - self.cp_size = None - - self.pos_embedder.disable_context_parallel() - if self.extra_per_block_abs_pos_emb: - self.extra_pos_embedder.disable_context_parallel() - - # Loop through the model to disable context parallel. - for block in self.blocks.values(): - for layer in block.blocks: - if layer.block_type in ["mlp", "ff"]: - continue - elif layer.block_type in ["cross_attn", "ca"]: - continue - else: - layer.block.attn.attn_op.cp_group = None - layer.block.attn.attn_op.cp_ranks = None - layer.block.attn.attn_op.cp_stream = None - - log.debug("[CP] Disable context parallelism.") - - @property - def is_context_parallel_enabled(self): - return self.cp_group is not None diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py b/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py deleted file mode 100644 index d47bc49e7a8299fe565e8d926e8da318dfa82eda..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc.py +++ /dev/null @@ -1,344 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -ControlNet Encoder based on GeneralDIT -""" - -from typing import List, Optional, Tuple - -import torch -from einops import rearrange - -# from megatron.core import parallel_state -from torch import nn -from torchvision import transforms - -from cosmos_transfer1.diffusion.conditioner import DataType -from cosmos_transfer1.diffusion.module.blocks import PatchEmbed, zero_module -from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp -from cosmos_transfer1.diffusion.networks.general_dit_video_conditioned import VideoExtendGeneralDIT as GeneralDIT - - -class GeneralDITEncoder(GeneralDIT): - """ - ControlNet Encoder based on GeneralDIT. Heavily borrowed from GeneralDIT with minor modifications. 
- """ - - def __init__(self, *args, **kwargs): - hint_channels = kwargs.pop("hint_channels", 16) - self.dropout_ctrl_branch = kwargs.pop("dropout_ctrl_branch", 0.5) - num_control_blocks = kwargs.pop("num_control_blocks", None) - if num_control_blocks is not None: - assert num_control_blocks > 0 and num_control_blocks <= kwargs["num_blocks"] - kwargs["layer_mask"] = [False] * num_control_blocks + [True] * (kwargs["num_blocks"] - num_control_blocks) - self.random_drop_control_blocks = kwargs.pop("random_drop_control_blocks", False) - super().__init__(*args, **kwargs) - num_blocks = self.num_blocks - model_channels = self.model_channels - layer_mask = kwargs.get("layer_mask", None) - layer_mask = [False] * num_blocks if layer_mask is None else layer_mask - self.layer_mask = layer_mask - self.hint_channels = hint_channels - self.build_hint_patch_embed() - hint_nf = [16, 16, 32, 32, 96, 96, 256] - nonlinearity = nn.SiLU() - input_hint_block = [nn.Linear(model_channels, hint_nf[0]), nonlinearity] - for i in range(len(hint_nf) - 1): - input_hint_block += [nn.Linear(hint_nf[i], hint_nf[i + 1]), nonlinearity] - self.input_hint_block = nn.Sequential(*input_hint_block) - # Initialize weights - self.initialize_weights() - self.zero_blocks = nn.ModuleDict() - for idx in range(num_blocks): - if layer_mask[idx]: - continue - self.zero_blocks[f"block{idx}"] = zero_module(nn.Linear(model_channels, model_channels)) - self.input_hint_block.append(zero_module(nn.Linear(hint_nf[-1], model_channels))) - - def build_hint_patch_embed(self): - concat_padding_mask, in_channels, patch_spatial, patch_temporal, model_channels = ( - self.concat_padding_mask, - self.hint_channels, - self.patch_spatial, - self.patch_temporal, - self.model_channels, - ) - in_channels = in_channels + 1 if concat_padding_mask else in_channels - self.x_embedder2 = PatchEmbed( - spatial_patch_size=patch_spatial, - temporal_patch_size=patch_temporal, - in_channels=in_channels, - out_channels=model_channels, - bias=False, - ) - - def prepare_hint_embedded_sequence( - self, x_B_C_T_H_W: torch.Tensor, fps: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if self.concat_padding_mask: - padding_mask = transforms.functional.resize( - padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST - ) - padding_mask = padding_mask.unsqueeze(2).expand(x_B_C_T_H_W.size(0), -1, x_B_C_T_H_W.size(2), -1, -1) - x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, padding_mask], dim=1) # [B, C+1, T, H, W] - - x_B_T_H_W_D = self.x_embedder2(x_B_C_T_H_W) - - if "rope" in self.pos_emb_cls.lower(): - return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps) - - if "fps_aware" in self.pos_emb_cls: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] - else: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] - return x_B_T_H_W_D, None - - def encode_hint( - self, - hint: torch.Tensor, - fps: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - ) -> torch.Tensor: - assert hint.size(1) <= self.hint_channels, f"Expected hint channels <= {self.hint_channels}, got {hint.size(1)}" - if hint.size(1) < self.hint_channels: - padding_shape = list(hint.shape) - padding_shape[1] = self.hint_channels - hint.size(1) - hint = torch.cat([hint, torch.zeros(*padding_shape, dtype=hint.dtype, device=hint.device)], dim=1) - assert isinstance( - 
data_type, DataType - ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." - - hint_B_T_H_W_D, _ = self.prepare_hint_embedded_sequence(hint, fps=fps, padding_mask=padding_mask) - hint = rearrange(hint_B_T_H_W_D, "B T H W D -> T H W B D") - - guided_hint = self.input_hint_block(hint) - return guided_hint - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - hint_key: Optional[str] = None, - base_model: Optional[nn.Module] = None, - control_weight: Optional[float] = 1.0, - num_layers_to_use: Optional[int] = -1, - condition_video_input_mask: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: - """ - Args: - x: (B, C, T, H, W) tensor of spatial-temp inputs - timesteps: (B, ) tensor of timesteps - crossattn_emb: (B, N, D) tensor of cross-attention embeddings - crossattn_mask: (B, N) tensor of cross-attention masks - regional_contexts: Optional list of regional context tensors - region_masks: Optional tensor of region masks - """ - # record the input as they are replaced in this forward - x_input = x - crossattn_emb_input = crossattn_emb - crossattn_mask_input = crossattn_mask - condition_video_input_mask_input = condition_video_input_mask - - regional_contexts_input = regional_contexts - region_masks_input = region_masks - - hint = kwargs.pop(hint_key) - if hint is None: - print("using none hint") - return base_model.net.forward( - x=x_input, - timesteps=timesteps, - crossattn_emb=crossattn_emb_input, - crossattn_mask=crossattn_mask_input, - fps=fps, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - condition_video_input_mask=condition_video_input_mask_input, - regional_contexts=regional_contexts_input, - region_masks=region_masks_input, - **kwargs, - ) - if hasattr(self, "hint_encoders"): # for multicontrol - guided_hints = [] - for i in range(hint.shape[1]): - self.input_hint_block = self.hint_encoders[i].input_hint_block - self.pos_embedder = self.hint_encoders[i].pos_embedder - self.x_embedder2 = self.hint_encoders[i].x_embedder2 - guided_hints += [self.encode_hint(hint[:, i], fps=fps, padding_mask=padding_mask, data_type=data_type)] - else: - guided_hints = self.encode_hint(hint, fps=fps, padding_mask=padding_mask, data_type=data_type) - guided_hints = torch.chunk(guided_hints, hint.shape[0] // x.shape[0], dim=3) - # Only support multi-control at inference time - assert len(guided_hints) == 1 or not torch.is_grad_enabled() - - assert isinstance( - data_type, DataType - ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
- - B, C, T, H, W = x.shape - if data_type == DataType.VIDEO: - if condition_video_input_mask is not None: - if self.cp_group is not None: - condition_video_input_mask = split_inputs_cp( - condition_video_input_mask, seq_dim=2, cp_group=self.cp_group - ) - input_list = [x, condition_video_input_mask] - x = torch.cat(input_list, dim=1) - elif data_type == DataType.IMAGE: - # For image, we dont have condition_video_input_mask, or condition_video_pose - # We need to add the extra channel for video condition mask - padding_channels = self.in_channels - x.shape[1] - x = torch.cat([x, torch.zeros((B, padding_channels, T, H, W), dtype=x.dtype, device=x.device)], dim=1) - else: - assert x.shape[1] == self.in_channels, f"Expected {self.in_channels} channels, got {x.shape[1]}" - - self.crossattn_emb = crossattn_emb - self.crossattn_mask = crossattn_mask - - if self.use_cross_attn_mask: - crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] - else: - crossattn_mask = None - - crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") - if crossattn_mask: - crossattn_mask = rearrange(crossattn_mask, "B M -> M B") - if regional_contexts is not None: - regional_contexts = rearrange(regional_contexts, "B R M D -> R M B D") - if region_masks is not None: - region_masks = rearrange(region_masks, "B R T H W -> R T H W B") - - outs = {} - - num_control_blocks = self.layer_mask.index(True) - num_layers_to_use = num_control_blocks - control_gate_per_layer = [i < num_layers_to_use for i in range(num_control_blocks)] - - if isinstance(control_weight, torch.Tensor): - if control_weight.ndim == 0: # Single scalar tensor - control_weight = [float(control_weight)] - elif control_weight.ndim == 1: # List of scalar weights - control_weight = [float(w) for w in control_weight] - else: # Spatial-temporal weight maps - if self.cp_group is not None: - control_weight = split_inputs_cp(control_weight, seq_dim=3, cp_group=self.cp_group) - control_weight = [w for w in control_weight] # Keep as tensor - else: - control_weight = [control_weight] * len(guided_hints) - - x_before_blocks = x.clone() - for i, guided_hint in enumerate(guided_hints): - x = x_before_blocks - if hasattr(self, "hint_encoders"): # for multicontrol - blocks = self.hint_encoders[i].blocks - zero_blocks = self.hint_encoders[i].zero_blocks - t_embedder = self.hint_encoders[i].t_embedder - affline_norm = self.hint_encoders[i].affline_norm - self.x_embedder = self.hint_encoders[i].x_embedder - self.extra_pos_embedder = self.hint_encoders[i].extra_pos_embedder - else: - blocks = self.blocks - zero_blocks = self.zero_blocks - t_embedder = self.t_embedder - affline_norm = self.affline_norm - - x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( - x, fps=fps, padding_mask=padding_mask - ) - # logging affline scale information - affline_scale_log_info = {} - - timesteps_B_D, adaln_lora_B_3D = t_embedder(timesteps.flatten()) - affline_emb_B_D = timesteps_B_D - affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() - - if scalar_feature is not None: - raise NotImplementedError("Scalar feature is not implemented yet.") - - affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() - affline_emb_B_D = affline_norm(affline_emb_B_D) - - # for logging purpose - self.affline_scale_log_info = affline_scale_log_info - self.affline_emb = affline_emb_B_D - - x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - 
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" - ) - - for idx, (name, block) in enumerate(blocks.items()): - assert ( - blocks["block0"].x_format == block.x_format - ), f"First block has x_format {blocks[0].x_format}, got {block.x_format}" - x = block( - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - adaln_lora_B_3D=adaln_lora_B_3D, - extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - regional_contexts=regional_contexts, - region_masks=region_masks, - ) - if guided_hint is not None: - x = x + guided_hint - guided_hint = None - - gate = control_gate_per_layer[idx] - if isinstance(control_weight[i], (float, int)) or control_weight[i].ndim < 2: - hint_val = zero_blocks[name](x) * control_weight[i] * gate - else: # Spatial-temporal weights [num_controls, B, 1, T, H, W] - control_feat = zero_blocks[name](x) - # Get current feature dimensions - weight_map = control_weight[i] # [B, 1, T, H, W] - # Reshape to match THWBD format - weight_map = weight_map.permute(2, 3, 4, 0, 1) # [T, H, W, B, 1] - hint_val = control_feat * weight_map * gate - if name not in outs: - outs[name] = hint_val - else: - outs[name] += hint_val - - output = base_model.net.forward( - x=x_input, - timesteps=timesteps, - crossattn_emb=crossattn_emb_input, - crossattn_mask=crossattn_mask_input, - fps=fps, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - x_ctrl=outs, - condition_video_input_mask=condition_video_input_mask_input, - regional_contexts=regional_contexts_input, - region_masks=region_masks_input, - **kwargs, - ) - return output diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc_multiview.py b/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc_multiview.py deleted file mode 100644 index 1decb26a06bf9bc7c3ba54547c204d7f547dec93..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit_ctrl_enc_multiview.py +++ /dev/null @@ -1,407 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
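Editor's note: the two ControlNet-style encoders deleted above and below share one key trick — the final layer of the hint embedding stack and every per-block output projection are wrapped in `zero_module`, so the control branch contributes exactly zero at initialization and only gradually learns to steer the base model. Below is a minimal, self-contained sketch of that pattern; the helper name `zero_module` comes from the deleted code, while the toy dimensions and variable names are assumptions for illustration only.

import torch
from torch import nn

def zero_module(module: nn.Module) -> nn.Module:
    # Zero all parameters so the wrapped layer outputs exactly zero at init.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

model_channels = 8
control_proj = zero_module(nn.Linear(model_channels, model_channels))

x = torch.randn(2, model_channels)      # features from the base branch
hint = torch.randn(2, model_channels)   # features from the control branch
out = x + control_proj(hint)            # at step 0 this is an exact no-op
assert torch.allclose(out, x)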
- -""" -ControlNet Encoder based on GeneralDIT -""" - -from typing import List, Optional, Tuple - -import numpy as np -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import nn -from torchvision import transforms - -from cosmos_transfer1.diffusion.conditioner import DataType -from cosmos_transfer1.diffusion.module.attention import get_normalization -from cosmos_transfer1.diffusion.module.blocks import ( - FinalLayer, - GeneralDITTransformerBlock, - PatchEmbed, - TimestepEmbedding, - Timesteps, - zero_module, -) -from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp -from cosmos_transfer1.diffusion.module.position_embedding import ( - MultiCameraSinCosPosEmbAxis, - MultiCameraVideoRopePosition3DEmb, -) -from cosmos_transfer1.diffusion.networks.general_dit_multi_view import MultiViewGeneralDIT -from cosmos_transfer1.utils import log - - -class GeneralDITMultiviewEncoder(MultiViewGeneralDIT): - """ - ControlNet Encoder based on MultiViewGeneralDIT. Based on GeneralDITEncoder - """ - - def __init__(self, *args, in_channels, is_extend_model=False, **kwargs): - self.is_extend_model = is_extend_model - if is_extend_model: - new_input_channels = in_channels + 1 - log.info(f"Updating input channels to {new_input_channels} to accomodate cond_mask") - else: - new_input_channels = in_channels - - if kwargs.get("add_augment_sigma_embedding", None) is not None: - self.add_augment_sigma_embedding = kwargs.pop("add_augment_sigma_embedding") - else: - self.add_augment_sigma_embedding = False - hint_channels = kwargs.pop("hint_channels", 16) - self.dropout_ctrl_branch = kwargs.pop("dropout_ctrl_branch", 0.5) - num_control_blocks = kwargs.pop("num_control_blocks", None) - if num_control_blocks is not None: - assert num_control_blocks > 0 and num_control_blocks <= kwargs["num_blocks"] - kwargs["layer_mask"] = [False] * num_control_blocks + [True] * (kwargs["num_blocks"] - num_control_blocks) - self.random_drop_control_blocks = kwargs.pop("random_drop_control_blocks", False) - super().__init__(*args, in_channels=new_input_channels, **kwargs) - num_blocks = self.num_blocks - model_channels = self.model_channels - layer_mask = kwargs.get("layer_mask", None) - layer_mask = [False] * num_blocks if layer_mask is None else layer_mask - self.layer_mask = layer_mask - self.hint_channels = hint_channels - self.build_hint_patch_embed() - hint_nf = [16, 16, 32, 32, 96, 96, 256] - nonlinearity = nn.SiLU() - input_hint_block = [nn.Linear(model_channels, hint_nf[0]), nonlinearity] - for i in range(len(hint_nf) - 1): - input_hint_block += [nn.Linear(hint_nf[i], hint_nf[i + 1]), nonlinearity] - self.input_hint_block = nn.Sequential(*input_hint_block) - # Initialize weights - self.initialize_weights() - self.zero_blocks = nn.ModuleDict() - for idx in range(num_blocks): - if layer_mask[idx]: - continue - self.zero_blocks[f"block{idx}"] = zero_module(nn.Linear(model_channels, model_channels)) - self.input_hint_block.append(zero_module(nn.Linear(hint_nf[-1], model_channels))) - - def _set_sequence_parallel(self, status: bool): - self.zero_blocks.sequence_parallel = status - self.input_hint_block.sequence_parallel = status - super()._set_sequence_parallel(status) - - def build_hint_patch_embed(self): - concat_padding_mask, in_channels, patch_spatial, patch_temporal, model_channels = ( - self.concat_padding_mask, - self.hint_channels, - self.patch_spatial, - self.patch_temporal, - self.model_channels, - ) - in_channels = in_channels + 1 if concat_padding_mask else 
in_channels - self.x_embedder2 = PatchEmbed( - spatial_patch_size=patch_spatial, - temporal_patch_size=patch_temporal, - in_channels=in_channels, - out_channels=model_channels, - bias=False, - ) - - def prepare_hint_embedded_sequence( - self, x_B_C_T_H_W: torch.Tensor, fps: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if self.concat_padding_mask: - padding_mask = transforms.functional.resize( - padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST - ) - x_B_C_T_H_W = torch.cat( - [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[2], 1, 1)], - dim=1, - ) - - x_B_T_H_W_D = self.x_embedder2(x_B_C_T_H_W) - - if "rope" in self.pos_emb_cls.lower(): - return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps) - - if "fps_aware" in self.pos_emb_cls: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] - else: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] - return x_B_T_H_W_D, None - - def encode_hint( - self, - hint: torch.Tensor, - fps: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - ) -> torch.Tensor: - assert hint.size(1) <= self.hint_channels, f"Expected hint channels <= {self.hint_channels}, got {hint.size(1)}" - if hint.size(1) < self.hint_channels: - padding_shape = list(hint.shape) - padding_shape[1] = self.hint_channels - hint.size(1) - hint = torch.cat([hint, torch.zeros(*padding_shape, dtype=hint.dtype, device=hint.device)], dim=1) - assert isinstance( - data_type, DataType - ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
- - hint_B_T_H_W_D, _ = self.prepare_hint_embedded_sequence(hint, fps=fps, padding_mask=padding_mask) - - if self.blocks["block0"].x_format == "THWBD": - hint = rearrange(hint_B_T_H_W_D, "B T H W D -> T H W B D") - else: - raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") - - guided_hint = self.input_hint_block(hint) - return guided_hint - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - hint_key: Optional[str] = None, - base_model: Optional[nn.Module] = None, - control_weight: Optional[float] = 1.0, - num_layers_to_use: Optional[int] = -1, - condition_video_input_mask: Optional[torch.Tensor] = None, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - view_indices_B_T: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: - """ - Args: - x: (B, C, T, H, W) tensor of spatial-temp inputs - timesteps: (B, ) tensor of timesteps - crossattn_emb: (B, N, D) tensor of cross-attention embeddings - crossattn_mask: (B, N) tensor of cross-attention masks - """ - # record the input as they are replaced in this forward - x_input = x - frame_repeat = kwargs.get("frame_repeat", None) - original_shape = x.shape - crossattn_emb_input = crossattn_emb - crossattn_mask_input = crossattn_mask - condition_video_input_mask_input = condition_video_input_mask - hint = kwargs.pop(hint_key) - if hint is None: - log.info("using none hint") - return base_model.net.forward( - x=x_input, - timesteps=timesteps, - crossattn_emb=crossattn_emb_input, - crossattn_mask=crossattn_mask_input, - fps=fps, - image_size=image_size, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - condition_video_input_mask=condition_video_input_mask_input, - latent_condition=latent_condition, - latent_condition_sigma=latent_condition_sigma, - view_indices_B_T=view_indices_B_T, - **kwargs, - ) - - if hasattr(self, "hint_encoders"): # for multicontrol - guided_hints = [] - for i in range(hint.shape[1]): - self.input_hint_block = self.hint_encoders[i].input_hint_block - self.pos_embedder = self.hint_encoders[i].pos_embedder - self.x_embedder2 = self.hint_encoders[i].x_embedder2 - guided_hints += [self.encode_hint(hint[:, i], fps=fps, padding_mask=padding_mask, data_type=data_type)] - else: - guided_hints = self.encode_hint(hint, fps=fps, padding_mask=padding_mask, data_type=data_type) - guided_hints = torch.chunk(guided_hints, hint.shape[0] // x.shape[0], dim=3) - # Only support multi-control at inference time - assert len(guided_hints) == 1 or not torch.is_grad_enabled() - - assert isinstance( - data_type, DataType - ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
- - B, C, T, H, W = x.shape - if data_type == DataType.VIDEO: - if condition_video_input_mask is not None: - if self.cp_group is not None: - condition_video_input_mask = rearrange( - condition_video_input_mask, "B C (V T) H W -> B C V T H W", V=self.n_views - ) - condition_video_input_mask = split_inputs_cp( - condition_video_input_mask, seq_dim=3, cp_group=self.cp_group - ) - condition_video_input_mask = rearrange( - condition_video_input_mask, "B C V T H W -> B C (V T) H W", V=self.n_views - ) - input_list = [x, condition_video_input_mask] - x = torch.cat( - input_list, - dim=1, - ) - - elif data_type == DataType.IMAGE: - # For image, we dont have condition_video_input_mask, or condition_video_pose - # We need to add the extra channel for video condition mask - padding_channels = self.in_channels - x.shape[1] - x = torch.cat([x, torch.zeros((B, padding_channels, T, H, W), dtype=x.dtype, device=x.device)], dim=1) - else: - assert x.shape[1] == self.in_channels, f"Expected {self.in_channels} channels, got {x.shape[1]}" - - self.crossattn_emb = crossattn_emb - self.crossattn_mask = crossattn_mask - - if self.use_cross_attn_mask: - crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] - else: - crossattn_mask = None - - if self.blocks["block0"].x_format == "THWBD": - crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") - if crossattn_mask: - crossattn_mask = rearrange(crossattn_mask, "B M -> M B") - - outs = {} - - num_control_blocks = self.layer_mask.index(True) - num_layers_to_use = num_control_blocks - control_gate_per_layer = [i < num_layers_to_use for i in range(num_control_blocks)] - - if isinstance(control_weight, torch.Tensor): - if control_weight.ndim == 0: # Single scalar tensor - control_weight = [float(control_weight)] * len(guided_hints) - elif control_weight.ndim == 1: # List of scalar weights - control_weight = [float(w) for w in control_weight] - else: # Spatial-temporal weight maps - if self.cp_group is not None: - control_weight = split_inputs_cp(control_weight, seq_dim=3, cp_group=self.cp_group) - control_weight = [w for w in control_weight] # Keep as tensor - else: - control_weight = [control_weight] * len(guided_hints) - - # max_norm = {} - x_before_blocks = x.clone() - for i, guided_hint in enumerate(guided_hints): - x = x_before_blocks - if hasattr(self, "hint_encoders"): # for multicontrol - blocks = self.hint_encoders[i].blocks - zero_blocks = self.hint_encoders[i].zero_blocks - t_embedder = self.hint_encoders[i].t_embedder - affline_norm = self.hint_encoders[i].affline_norm - self.x_embedder = self.hint_encoders[i].x_embedder - self.extra_pos_embedder = self.hint_encoders[i].extra_pos_embedder - else: - blocks = self.blocks - zero_blocks = self.zero_blocks - t_embedder = self.t_embedder - affline_norm = self.affline_norm - - x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( - x, - fps=fps, - padding_mask=padding_mask, - latent_condition=latent_condition, - latent_condition_sigma=latent_condition_sigma, - frame_repeat=frame_repeat, - view_indices_B_T=view_indices_B_T, - ) - # logging affline scale information - affline_scale_log_info = {} - - timesteps_B_D, adaln_lora_B_3D = t_embedder(timesteps.flatten()) - affline_emb_B_D = timesteps_B_D - affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() - - if scalar_feature is not None: - raise NotImplementedError("Scalar feature is not implemented yet.") - - affline_scale_log_info["affline_emb_B_D"] = 
affline_emb_B_D.detach() - affline_emb_B_D = affline_norm(affline_emb_B_D) - - # for logging purpose - self.affline_scale_log_info = affline_scale_log_info - self.affline_emb = affline_emb_B_D - - if self.blocks["block0"].x_format == "THWBD": - x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" - ) - - elif self.blocks["block0"].x_format == "BTHWD": - x = x_B_T_H_W_D - else: - raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") - - for idx, (name, block) in enumerate(blocks.items()): - assert ( - blocks["block0"].x_format == block.x_format - ), f"First block has x_format {blocks[0].x_format}, got {block.x_format}" - x = block( - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - adaln_lora_B_3D=adaln_lora_B_3D, - extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - ) - if guided_hint is not None: - x = x + guided_hint - guided_hint = None - - gate = control_gate_per_layer[idx] - if isinstance(control_weight[i], (float, int)) or control_weight[i].ndim < 2: - hint_val = zero_blocks[name](x) * control_weight[i] * gate # coin_flip * gate - else: # Spatial-temporal weights [num_controls, B, 1, T, H, W] - control_feat = zero_blocks[name](x) - - # Get current feature dimensions - if self.blocks["block0"].x_format == "THWBD": - weight_map = control_weight[i] # [B, 1, T, H, W] - # Reshape to match THWBD format - weight_map = weight_map.permute(2, 3, 4, 0, 1) # [T, H, W, B, 1] - weight_map = weight_map.view(T * H * W, 1, 1, B, 1) - - else: # BTHWD format - raise NotImplementedError("BTHWD format for weight map is not implemented yet.") - hint_val = control_feat * weight_map * gate # * coin_flip - - if name not in outs: - outs[name] = hint_val - else: - outs[name] += hint_val - - output = base_model.net.forward( - x=x_input, - timesteps=timesteps, - crossattn_emb=crossattn_emb_input, - crossattn_mask=crossattn_mask_input, - fps=fps, - image_size=image_size, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - x_ctrl=outs, - condition_video_input_mask=condition_video_input_mask_input, - latent_condition=latent_condition, - latent_condition_sigma=latent_condition_sigma, - view_indices_B_T=view_indices_B_T, - **kwargs, - ) - return output diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit_multi_view.py b/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit_multi_view.py deleted file mode 100644 index 729c8e824ff7bac7843d283f3069921bb65390b6..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit_multi_view.py +++ /dev/null @@ -1,442 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
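Editor's note: in the multiview encoder deleted above, `control_weight` may be a scalar, a per-control list of scalars, or a full spatial-temporal map of shape [B, 1, T, H, W]; in the last case it is permuted into the THWBD layout used by the transformer blocks before it scales the zero-block output. The permute-and-broadcast step can be checked in isolation with the hedged sketch below (tensor sizes are arbitrary and chosen purely for illustration).

import torch

B, T, H, W, D = 2, 4, 6, 6, 16
weight_map = torch.rand(B, 1, T, H, W)             # per-pixel control weight
control_feat = torch.randn(T, H, W, B, D)          # block output in THWBD layout

weight_thwb1 = weight_map.permute(2, 3, 4, 0, 1)   # -> [T, H, W, B, 1]
hint_val = control_feat * weight_thwb1             # broadcasts over the D axis
assert hint_val.shape == (T, H, W, B, D)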
- -from typing import Optional, Tuple - -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import DictType, nn -from torchvision import transforms - -from cosmos_transfer1.diffusion.conditioner import DataType -from cosmos_transfer1.diffusion.module.attention import get_normalization -from cosmos_transfer1.diffusion.module.blocks import ( - FinalLayer, - GeneralDITTransformerBlock, - PatchEmbed, - TimestepEmbedding, - Timesteps, -) -from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp -from cosmos_transfer1.diffusion.module.position_embedding import ( - MultiCameraSinCosPosEmbAxis, - MultiCameraVideoRopePosition3DEmb, -) -from cosmos_transfer1.diffusion.networks.general_dit import GeneralDIT -from cosmos_transfer1.utils import log - - -class MultiViewGeneralDIT(GeneralDIT): - def __init__( - self, - *args, - n_views: int = 3, - n_views_emb: int = -1, - camera_condition_dim: int = 3, - traj_condition_dim: int = 3, - concat_camera_embedding: bool = True, - concat_traj_embedding: bool = False, - add_repeat_frame_embedding: bool = False, - **kwargs, - ): - if kwargs.get("add_augment_sigma_embedding", None) is not None: - self.add_augment_sigma_embedding = kwargs.pop("add_augment_sigma_embedding") - - self.n_views = n_views - if n_views_emb < 0: - self.n_views_emb = n_views - else: - self.n_views_emb = n_views_emb - - self.camera_condition_dim = camera_condition_dim - self.traj_condition_dim = traj_condition_dim - self.concat_camera_embedding = concat_camera_embedding - self.concat_traj_embedding = concat_traj_embedding - self.add_repeat_frame_embedding = add_repeat_frame_embedding - - super().__init__(*args, **kwargs) - # reinit self.blocks - del self.blocks - self.blocks = nn.ModuleDict() - - layer_mask = [False] * self.num_blocks if kwargs["layer_mask"] is None else kwargs["layer_mask"] - assert ( - len(layer_mask) == self.num_blocks - ), f"Layer mask length {len(layer_mask)} does not match num_blocks { self.num_blocks}" - for idx in range(self.num_blocks): - if layer_mask[idx]: - continue - self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( - x_dim=self.model_channels, - context_dim=kwargs["crossattn_emb_channels"], - num_heads=self.num_heads, - block_config=self.block_config, - mlp_ratio=kwargs["mlp_ratio"], - x_format=self.block_x_format, - use_adaln_lora=self.use_adaln_lora, - adaln_lora_dim=self.adaln_lora_dim, - n_views=self.n_views, - ) - self.view_embeddings = nn.Embedding(self.n_views_emb, camera_condition_dim) # Learnable embedding layer - - if self.concat_traj_embedding: - self.traj_embeddings = nn.Linear(192, self.traj_condition_dim) # Learnable embedding layer - if self.add_repeat_frame_embedding: - self.repeat_frame_embedding = nn.Linear(1, camera_condition_dim) # Learnable embedding layer - - self.initialize_weights() - - def build_patch_embed(self): - ( - concat_padding_mask, - in_channels, - patch_spatial, - patch_temporal, - model_channels, - camera_condition_dim, - traj_condition_dim, - ) = ( - self.concat_padding_mask, - self.in_channels, - self.patch_spatial, - self.patch_temporal, - self.model_channels, - self.camera_condition_dim, - self.traj_condition_dim, - ) - if self.concat_camera_embedding: - in_channels = in_channels + camera_condition_dim if camera_condition_dim > 0 else in_channels - - if self.concat_traj_embedding: - in_channels = in_channels + traj_condition_dim if traj_condition_dim > 0 else in_channels - - in_channels = in_channels + 1 if concat_padding_mask else in_channels - - 
self.x_embedder = PatchEmbed( - spatial_patch_size=patch_spatial, - temporal_patch_size=patch_temporal, - in_channels=in_channels, - out_channels=model_channels, - bias=False, - ) - - def build_pos_embed(self): - if self.pos_emb_cls == "rope3d": - cls_type = MultiCameraVideoRopePosition3DEmb - else: - raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") - - log.critical(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") - kwargs = dict( - model_channels=self.model_channels, - len_h=self.max_img_h // self.patch_spatial, - len_w=self.max_img_w // self.patch_spatial, - len_t=self.max_frames // self.patch_temporal, - is_learnable=self.pos_emb_learnable, - interpolation=self.pos_emb_interpolation, - head_dim=self.model_channels // self.num_heads, - h_extrapolation_ratio=self.rope_h_extrapolation_ratio, - w_extrapolation_ratio=self.rope_w_extrapolation_ratio, - t_extrapolation_ratio=self.rope_t_extrapolation_ratio, - n_views=self.n_views, - ) - self.pos_embedder = cls_type( - **kwargs, - ) - - if self.extra_per_block_abs_pos_emb: - kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio - kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio - kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio - self.extra_pos_embedder = MultiCameraSinCosPosEmbAxis( - **kwargs, - ) - - def forward_before_blocks( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - view_indices_B_T: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - **kwargs, - ) -> dict: - """ - Args: - x: (B, C, T, H, W) tensor of spatial-temp inputs - timesteps: (B, ) tensor of timesteps - crossattn_emb: (B, N, D) tensor of cross-attention embeddings - crossattn_mask: (B, N) tensor of cross-attention masks - """ - trajectory = kwargs.get("trajectory", None) - frame_repeat = kwargs.get("frame_repeat", None) - - del kwargs - assert isinstance( - data_type, DataType - ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
- original_shape = x.shape - x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( - x, - fps=fps, - padding_mask=padding_mask, - latent_condition=latent_condition, - latent_condition_sigma=latent_condition_sigma, - trajectory=trajectory, - frame_repeat=frame_repeat, - view_indices_B_T=view_indices_B_T, - ) - # logging affline scale information - affline_scale_log_info = {} - - timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) - affline_emb_B_D = timesteps_B_D - affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() - - if scalar_feature is not None: - raise NotImplementedError("Scalar feature is not implemented yet.") - - affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() - affline_emb_B_D = self.affline_norm(affline_emb_B_D) - - if self.use_cross_attn_mask: - crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] - else: - crossattn_mask = None - - if self.blocks["block0"].x_format == "THWBD": - x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" - ) - crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") - - if crossattn_mask: - crossattn_mask = rearrange(crossattn_mask, "B M -> M B") - - elif self.blocks["block0"].x_format == "BTHWD": - x = x_B_T_H_W_D - else: - raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") - output = { - "x": x, - "affline_emb_B_D": affline_emb_B_D, - "crossattn_emb": crossattn_emb, - "crossattn_mask": crossattn_mask, - "rope_emb_L_1_1_D": rope_emb_L_1_1_D, - "adaln_lora_B_3D": adaln_lora_B_3D, - "original_shape": original_shape, - "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - "regional_contexts": regional_contexts, - "region_masks": region_masks, - } - return output - - def prepare_embedded_sequence( - self, - x_B_C_T_H_W: torch.Tensor, - fps: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - trajectory: Optional[torch.Tensor] = None, - frame_repeat: Optional[torch.Tensor] = None, - view_indices_B_T: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: - """ - Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. - - Args: - x_B_C_T_H_W (torch.Tensor): video - fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. - If None, a default value (`self.base_fps`) will be used. - padding_mask (Optional[torch.Tensor]): current it is not used - - Returns: - Tuple[torch.Tensor, Optional[torch.Tensor]]: - - A tensor of shape (B, T, H, W, D) with the embedded sequence. - - An optional positional embedding tensor, returned only if the positional embedding class - (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. - - Notes: - - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. - - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. - - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using - the `self.pos_embedder` with the shape [T, H, W]. 
- - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` - with the fps tensor. - - Otherwise, the positional embeddings are generated without considering fps. - """ - if self.concat_padding_mask: - padding_mask = transforms.functional.resize( - padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST - ) - x_B_C_T_H_W = torch.cat( - [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 - ) - - if view_indices_B_T is None: - view_indices = torch.arange(self.n_views).clamp(max=self.n_views_emb - 1) # View indices [0, 1, ..., V-1] - view_indices = view_indices.to(x_B_C_T_H_W.device) - view_embedding = self.view_embeddings(view_indices) # Shape: [V, embedding_dim] - view_embedding = rearrange(view_embedding, "V D -> D V") - view_embedding = ( - view_embedding.unsqueeze(0).unsqueeze(3).unsqueeze(4).unsqueeze(5) - ) # Shape: [1, D, V, 1, 1, 1] - else: - view_indices_B_T = view_indices_B_T.clamp(max=self.n_views_emb - 1) - view_indices_B_T = view_indices_B_T.to(x_B_C_T_H_W.device).long() - view_embedding = self.view_embeddings(view_indices_B_T) # B, (V T), D - view_embedding = rearrange(view_embedding, "B (V T) D -> B D V T", V=self.n_views) - view_embedding = view_embedding.unsqueeze(-1).unsqueeze(-1) # Shape: [B, D, V, T, 1, 1] - - if self.add_repeat_frame_embedding: - if frame_repeat is None: - frame_repeat = ( - torch.zeros([x_B_C_T_H_W.shape[0], view_embedding.shape[1]]) - .to(view_embedding.device) - .to(view_embedding.dtype) - ) - frame_repeat_embedding = self.repeat_frame_embedding(frame_repeat.unsqueeze(-1)) - frame_repeat_embedding = rearrange(frame_repeat_embedding, "B V D -> B D V") - view_embedding = view_embedding + frame_repeat_embedding.unsqueeze(3).unsqueeze(4).unsqueeze(5) - - x_B_C_V_T_H_W = rearrange(x_B_C_T_H_W, "B C (V T) H W -> B C V T H W", V=self.n_views) - view_embedding = view_embedding.expand( - x_B_C_V_T_H_W.shape[0], - view_embedding.shape[1], - view_embedding.shape[2], - x_B_C_V_T_H_W.shape[3], - x_B_C_V_T_H_W.shape[4], - x_B_C_V_T_H_W.shape[5], - ) # Shape: [B, V, 3, t, H, W] - if self.concat_traj_embedding: - traj_emb = self.traj_embeddings(trajectory) - traj_emb = traj_emb.unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(5) - traj_emb = traj_emb.expand( - x_B_C_V_T_H_W.shape[0], - traj_emb.shape[1], - view_embedding.shape[2], - x_B_C_V_T_H_W.shape[3], - x_B_C_V_T_H_W.shape[4], - x_B_C_V_T_H_W.shape[5], - ) # Shape: [B, V, 3, t, H, W] - - x_B_C_V_T_H_W = torch.cat([x_B_C_V_T_H_W, view_embedding, traj_emb], dim=1) - else: - x_B_C_V_T_H_W = torch.cat([x_B_C_V_T_H_W, view_embedding], dim=1) - - x_B_C_T_H_W = rearrange(x_B_C_V_T_H_W, " B C V T H W -> B C (V T) H W", V=self.n_views) - x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) - - if self.extra_per_block_abs_pos_emb: - extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) - else: - extra_pos_emb = None - - if "rope" in self.pos_emb_cls.lower(): - return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb - - if "fps_aware" in self.pos_emb_cls: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] - else: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] - return x_B_T_H_W_D, None, extra_pos_emb - - -class MultiViewVideoExtendGeneralDIT(MultiViewGeneralDIT): - def __init__(self, *args, in_channels=17, add_augment_sigma_embedding=False, **kwargs): - self.add_augment_sigma_embedding = 
add_augment_sigma_embedding - # extra channel for video condition mask - super().__init__(*args, in_channels=in_channels, **kwargs) - log.info(f"VideoExtendGeneralDIT in_channels: {in_channels}") - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - video_cond_bool: Optional[torch.Tensor] = None, - condition_video_indicator: Optional[torch.Tensor] = None, - condition_video_input_mask: Optional[torch.Tensor] = None, - condition_video_augment_sigma: Optional[torch.Tensor] = None, - condition_video_pose: Optional[torch.Tensor] = None, - view_indices_B_T: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """Args: - condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation - condition_video_pose: (B, 1, T, H, W) tensor of pose condition - """ - B, C, T, H, W = x.shape - - if data_type == DataType.VIDEO: - assert ( - condition_video_input_mask is not None - ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" - if parallel_state.is_initialized(): - condition_video_input_mask = rearrange( - condition_video_input_mask, "B C (V T) H W -> B C V T H W", V=self.n_views - ) - condition_video_input_mask = split_inputs_cp( - condition_video_input_mask, seq_dim=3, cp_group=self.cp_group - ) - condition_video_input_mask = rearrange( - condition_video_input_mask, "B C V T H W -> B C (V T) H W", V=self.n_views - ) - input_list = [x, condition_video_input_mask] - if condition_video_pose is not None: - if condition_video_pose.shape[2] > T: - log.warning( - f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" - ) - condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() - input_list.append(condition_video_pose) - x = torch.cat( - input_list, - dim=1, - ) - - return super().forward( - x=x, - timesteps=timesteps, - crossattn_emb=crossattn_emb, - crossattn_mask=crossattn_mask, - fps=fps, - image_size=image_size, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - condition_video_augment_sigma=condition_video_augment_sigma, - view_indices_B_T=view_indices_B_T, - **kwargs, - ) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py b/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py deleted file mode 100644 index 98bbf6ed76bfb0e014302b77bf812f91f2737608..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/networks/general_dit_video_conditioned.py +++ /dev/null @@ -1,222 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import nn - -from cosmos_transfer1.diffusion.conditioner import DataType -from cosmos_transfer1.diffusion.module.blocks import TimestepEmbedding, Timesteps -from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp -from cosmos_transfer1.diffusion.networks.general_dit import GeneralDIT -from cosmos_transfer1.utils import log - - -class VideoExtendGeneralDIT(GeneralDIT): - def __init__(self, *args, in_channels=16 + 1, add_augment_sigma_embedding=False, **kwargs): - self.add_augment_sigma_embedding = add_augment_sigma_embedding - - # extra channel for video condition mask - super().__init__(*args, in_channels=in_channels, **kwargs) - log.debug(f"VideoExtendGeneralDIT in_channels: {in_channels}") - - def build_additional_timestamp_embedder(self): - super().build_additional_timestamp_embedder() - if self.add_augment_sigma_embedding: - log.info("Adding augment sigma embedding") - self.augment_sigma_embedder = nn.Sequential( - Timesteps(self.model_channels), - TimestepEmbedding(self.model_channels, self.model_channels, use_adaln_lora=self.use_adaln_lora), - ) - - def initialize_weights(self): - if self.add_augment_sigma_embedding: - # Initialize timestep embedding for augment sigma - nn.init.normal_(self.augment_sigma_embedder[1].linear_1.weight, std=0.02) - if self.augment_sigma_embedder[1].linear_1.bias is not None: - nn.init.constant_(self.augment_sigma_embedder[1].linear_1.bias, 0) - nn.init.normal_(self.augment_sigma_embedder[1].linear_2.weight, std=0.02) - if self.augment_sigma_embedder[1].linear_2.bias is not None: - nn.init.constant_(self.augment_sigma_embedder[1].linear_2.bias, 0) - - super().initialize_weights() # Call this last since it wil call TP weight init - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - video_cond_bool: Optional[torch.Tensor] = None, - condition_video_indicator: Optional[torch.Tensor] = None, - condition_video_input_mask: Optional[torch.Tensor] = None, - condition_video_augment_sigma: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """Forward pass of the video-conditioned DIT model. 
- - Args: - x: Input tensor of shape (B, C, T, H, W) - timesteps: Timestep tensor of shape (B,) - crossattn_emb: Cross attention embeddings of shape (B, N, D) - crossattn_mask: Optional cross attention mask of shape (B, N) - fps: Optional frames per second tensor - padding_mask: Optional padding mask tensor - scalar_feature: Optional scalar features tensor - data_type: Type of data being processed (default: DataType.VIDEO) - video_cond_bool: Optional video conditioning boolean tensor - condition_video_indicator: Optional video condition indicator tensor - condition_video_input_mask: Required mask tensor for video data type - condition_video_augment_sigma: Optional sigma values for conditional input augmentation - **kwargs: Additional keyword arguments - - Returns: - torch.Tensor: Output tensor - """ - B, C, T, H, W = x.shape - - if data_type == DataType.VIDEO: - assert condition_video_input_mask is not None, "condition_video_input_mask is required for video data type" - - if parallel_state.is_initialized(): - cp_group = parallel_state.get_context_parallel_group() - condition_video_input_mask = split_inputs_cp(condition_video_input_mask, seq_dim=2, cp_group=cp_group) - condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) - input_list = [x, condition_video_input_mask] - x = torch.cat( - input_list, - dim=1, - ) - - return super().forward( - x=x, - timesteps=timesteps, - crossattn_emb=crossattn_emb, - crossattn_mask=crossattn_mask, - fps=fps, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - condition_video_augment_sigma=condition_video_augment_sigma, - regional_contexts=regional_contexts, - region_masks=region_masks, - **kwargs, - ) - - def forward_before_blocks( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - condition_video_augment_sigma: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """ - Args: - x: (B, C, T, H, W) tensor of spatial-temp inputs - timesteps: (B, ) tensor of timesteps - crossattn_emb: (B, N, D) tensor of cross-attention embeddings - crossattn_mask: (B, N) tensor of cross-attention masks - - condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation - """ - del kwargs - assert isinstance( - data_type, DataType - ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
- original_shape = x.shape - x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( - x, - fps=fps, - padding_mask=padding_mask, - latent_condition=latent_condition, - latent_condition_sigma=latent_condition_sigma, - ) - # logging affline scale information - affline_scale_log_info = {} - - timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) - affline_emb_B_D = timesteps_B_D - affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() - - if scalar_feature is not None: - raise NotImplementedError("Scalar feature is not implemented yet.") - - if self.add_augment_sigma_embedding: - if condition_video_augment_sigma is None: - # Handling image case - # Note: for video case, when there is not condition frames, we also set it as zero, see extend_model augment_conditional_latent_frames function - assert data_type == DataType.IMAGE, "condition_video_augment_sigma is required for video data type" - condition_video_augment_sigma = torch.zeros_like(timesteps.flatten()) - - affline_augment_sigma_emb_B_D, _ = self.augment_sigma_embedder(condition_video_augment_sigma.flatten()) - affline_emb_B_D = affline_emb_B_D + affline_augment_sigma_emb_B_D - affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() - affline_emb_B_D = self.affline_norm(affline_emb_B_D) - - if self.use_cross_attn_mask: - crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] - else: - crossattn_mask = None - - x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" - ) - crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") - if crossattn_mask: - crossattn_mask = rearrange(crossattn_mask, "B M -> M B") - - # For regional contexts - if regional_contexts is not None: - regional_contexts = rearrange(regional_contexts, "B R M D -> R M B D") - - # For region masks (assuming 5D format) - if region_masks is not None: - # if len(region_masks.shape) == 5: - region_masks = rearrange(region_masks, "B R T H W -> R T H W B") - - output = { - "x": x, - "affline_emb_B_D": affline_emb_B_D, - "crossattn_emb": crossattn_emb, - "crossattn_mask": crossattn_mask, - "rope_emb_L_1_1_D": rope_emb_L_1_1_D, - "adaln_lora_B_3D": adaln_lora_B_3D, - "original_shape": original_shape, - "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - "regional_contexts": regional_contexts, - "region_masks": region_masks, - } - return output diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/README.md b/cosmos-transfer1/cosmos_transfer1/diffusion/training/README.md deleted file mode 100644 index 6f41fcb8ee64b3ed7638144f6061a98a3cc65b8d..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Training Modules - -This folder contains specialized versions of models and modules optimized for training. While some components (for example, the `GeneralDIT` defined in `training/networks/general_dit.py`) may appear duplicated from elsewhere in the repository, they include training-specific functionality including gradient checkpointing, training steps, tensor parallel and sequence parallel support, etc. 
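Editor's note: the deleted training README above states that the training copies of these networks add gradient checkpointing and tensor/sequence parallel support on top of the inference modules. As a rough, generic illustration of what gradient checkpointing means in PyTorch (this sketch is not the deleted training code; the class name and dimensions are invented for the example):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    def __init__(self, dim: int = 32):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Drop intermediate activations and recompute them in the backward pass,
        # trading extra compute for lower peak memory.
        return checkpoint(self.ff, x, use_reentrant=False)

x = torch.randn(4, 32, requires_grad=True)
CheckpointedBlock()(x).sum().backward()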
diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/callbacks/every_n.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/callbacks/every_n.py deleted file mode 100644 index bf6d5e6e410dc68c3bbf3a4efd050e766ab1eb27..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/callbacks/every_n.py +++ /dev/null @@ -1,86 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import abstractmethod -from typing import Optional - -import torch - -from cosmos_transfer1.utils import distributed, log -from cosmos_transfer1.utils.callback import Callback -from cosmos_transfer1.utils.model import Model -from cosmos_transfer1.utils.trainer import Trainer - - -class EveryN(Callback): - def __init__( - self, - every_n: Optional[int] = None, - step_size: int = 1, - barrier_after_run: bool = True, - run_at_start: bool = False, - ) -> None: - """Constructor for `EveryN`. - - Args: - every_n (int): Frequency with which callback is run during training. - step_size (int): Size of iteration step count. Default 1. - barrier_after_run (bool): Whether to have a distributed barrier after each execution. Default True, to avoid timeouts. - run_at_start (bool): Whether to run at the beginning of training. Default False. - """ - self.every_n = every_n - if self.every_n == 0: - log.warning( - f"every_n is set to 0. Callback {self.__class__.__name__} will be invoked only once in the beginning of the training. Calls happens on_training_step_end will be skipped." - ) - - self.step_size = step_size - self.barrier_after_run = barrier_after_run - self.run_at_start = run_at_start - - def on_training_step_end( - self, - model: Model, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - # every_n = 0 is a special case which means every_n_impl will be called only once in the beginning of the training - if self.every_n != 0: - trainer = self.trainer - global_step = iteration // self.step_size - should_run = (iteration == 1 and self.run_at_start) or ( - global_step % self.every_n == 0 - ) # (self.every_n - 1) - if should_run: - log.debug(f"Callback {self.__class__.__name__} fired on train_batch_end step {global_step}") - self.every_n_impl(trainer, model, data_batch, output_batch, loss, iteration) - log.debug(f"Callback {self.__class__.__name__} finished on train_batch_end step {global_step}") - # add necessary barrier to avoid timeout - if self.barrier_after_run: - distributed.barrier() - - @abstractmethod - def every_n_impl( - self, - trainer: Trainer, - model: Model, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int, - ) -> None: - ... 
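Editor's note: the `EveryN` callback deleted above fires its `every_n_impl` hook only on iterations where `iteration // step_size` is a multiple of `every_n`, with an optional extra firing on the very first step. The gating condition, reduced to a standalone function for clarity (argument names mirror the deleted constructor; the function itself is an editorial sketch, not part of the diff):

def should_fire(iteration: int, every_n: int, step_size: int = 1, run_at_start: bool = False) -> bool:
    if every_n == 0:
        # Special case from the deleted code: the callback runs only once, at startup,
        # and all on_training_step_end invocations are skipped.
        return False
    global_step = iteration // step_size
    return (iteration == 1 and run_at_start) or (global_step % every_n == 0)

assert should_fire(10, every_n=5)
assert not should_fire(7, every_n=5)
assert should_fire(1, every_n=5, run_at_start=True)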
diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/callbacks/grad_clip.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/callbacks/grad_clip.py deleted file mode 100644 index 9070fca91514a61f7f40b53e9db38a14567de8b4..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/callbacks/grad_clip.py +++ /dev/null @@ -1,102 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Tuple - -import torch -from megatron.core import parallel_state -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - -from cosmos_transfer1.utils import distributed -from cosmos_transfer1.utils.callback import GradClip as GradClipImage -from cosmos_transfer1.utils.callback import _fused_nan_to_num -from cosmos_transfer1.utils.model import Model - - -@dataclass -class _MagnitudeRecord: - state: float = 0 - iter_count: int = 0 - - def reset(self) -> None: - self.state = 0 - self.iter_count = 0 - - def update(self, cur_state: torch.Tensor) -> None: - self.state += cur_state - self.iter_count += 1 - - def get_stat(self) -> Tuple[float, float]: - if self.iter_count > 0: - avg_state = self.state / self.iter_count - avg_state = avg_state.item() - else: - avg_state = 0 - self.reset() - return avg_state - - -class GradClip(GradClipImage): - """ - adds support for TP - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.img_mag_log = _MagnitudeRecord() - self.video_mag_log = _MagnitudeRecord() - self._cur_state = None - - def on_training_step_start(self, model: Model, data_batch: dict[str, torch.Tensor], iteration: int = 0) -> None: - if model.is_image_batch(data_batch): - self._cur_state = self.img_mag_log - else: - self._cur_state = self.video_mag_log - - def on_before_optimizer_step( - self, - model_ddp: distributed.DistributedDataParallel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int = 0, - ) -> None: - del optimizer, scheduler - if isinstance(model_ddp, distributed.DistributedDataParallel): - model = model_ddp.module - else: - model = model_ddp - params = [] - if self.model_key is not None: - items = self.model_key.split(".") - for item in items: - model = getattr(model, item) - if self.force_finite: - for param in model.parameters(): - if param.grad is not None: - params.append(param.grad) - # torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad) - _fused_nan_to_num(params) - - if isinstance(model, FSDP) and self.fsdp_enabled: - total_norm = model.clip_grad_norm_(self.clip_norm) - else: - if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: - total_norm = model_ddp.module.clip_grad_norm_(self.clip_norm) - else: - total_norm = 
torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True) - - self._cur_state.update(total_norm) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/callbacks/iter_speed.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/callbacks/iter_speed.py deleted file mode 100644 index 14ce02dc71114a9fa9c739c922bcb7f5670d3d97..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/callbacks/iter_speed.py +++ /dev/null @@ -1,82 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import torch -from torch import Tensor - -from cosmos_transfer1.diffusion.training.callbacks.every_n import EveryN -from cosmos_transfer1.utils import log -from cosmos_transfer1.utils.distributed import rank0_only -from cosmos_transfer1.utils.model import Model -from cosmos_transfer1.utils.trainer import Trainer - - -class IterSpeed(EveryN): - """ - Args: - hit_thres (int): Number of iterations to wait before logging. - """ - - def __init__(self, *args, hit_thres: int = 5, **kwargs): - super().__init__(*args, **kwargs) - self.time = None - self.hit_counter = 0 - self.hit_thres = hit_thres - self.name = self.__class__.__name__ - self.last_hit_time = time.time() - - def on_training_step_end( - self, - model: Model, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - if self.hit_counter < self.hit_thres: - log.info( - f"Iteration {iteration}: " - f"Hit counter: {self.hit_counter + 1}/{self.hit_thres} | " - f"Loss: {loss.item():.4f} | " - f"Time: {time.time() - self.last_hit_time:.2f}s" - ) - self.hit_counter += 1 - self.last_hit_time = time.time() - #! useful for large scale training and avoid oom crash in the first two iterations!!! - torch.cuda.synchronize() - return - super().on_training_step_end(model, data_batch, output_batch, loss, iteration) - - @rank0_only - def every_n_impl( - self, - trainer: Trainer, - model: Model, - data_batch: dict[str, Tensor], - output_batch: dict[str, Tensor], - loss: Tensor, - iteration: int, - ) -> None: - if self.time is None: - self.time = time.time() - return - cur_time = time.time() - iter_speed = (cur_time - self.time) / self.every_n / self.step_size - - log.info(f"{iteration} : iter_speed {iter_speed:.2f} seconds per iteration | Loss: {loss.item():.4f}") - - self.time = cur_time diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/callbacks/low_precision.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/callbacks/low_precision.py deleted file mode 100644 index 0074abab84a1ba2f8bc9c327fb676dcb39c82690..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/callbacks/low_precision.py +++ /dev/null @@ -1,41 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - -from cosmos_transfer1.utils.callback import LowPrecisionCallback as BaseCallback -from cosmos_transfer1.utils.config import Config -from cosmos_transfer1.utils.model import Model -from cosmos_transfer1.utils.trainer import Trainer - - -class LowPrecisionCallback(BaseCallback): - """ - Config with non-primitive type makes it difficult to override the option. - The callback gets precision from model.precision instead. - """ - - def __init__(self, config: Config, trainer: Trainer, update_iter: int): - self.config = config - self.trainer = trainer - self.update_iter = update_iter - - def on_train_start(self, model: Model, iteration: int = 0) -> None: - assert model.precision in [ - torch.bfloat16, - torch.float16, - torch.half, - ], "LowPrecisionCallback must use a low precision dtype." - self.precision_type = model.precision diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/datasets/data_sources/item_dataset.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/datasets/data_sources/item_dataset.py deleted file mode 100644 index 0a4dcd91920a9ab1364af8a2a482495cb4bec0f9..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/datasets/data_sources/item_dataset.py +++ /dev/null @@ -1,22 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import dataclasses - - -@dataclasses.dataclass -class ItemDatasetConfig: - path: str - length: int diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/datasets/dataset_utils.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/datasets/dataset_utils.py deleted file mode 100644 index 963e4c9de3d2e7958dbaa0526b284650671965b3..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/datasets/dataset_utils.py +++ /dev/null @@ -1,311 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Adapted from: -https://github.com/bytedance/IRASim/blob/main/dataset/dataset_util.py -""" - -import base64 -import math -import os -from io import BytesIO - -import numpy as np -import torch -import torch.distributed as dist -import torchvision.transforms.functional as F -from PIL import Image - - -def is_dist_avail_and_initialized(): - if not dist.is_available(): - return False - if not dist.is_initialized(): - return False - return True - - -def get_rank(): - if not is_dist_avail_and_initialized(): - return 0 - return dist.get_rank() - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position - pos: a list of positions to be encoded: size (M,) - out: (M, D) - """ - assert embed_dim % 2 == 0 - omega = np.arange(embed_dim // 2, dtype=np.float32) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - assert embed_dim % 2 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): - """ - grid_size: int of the grid height and width - return: - pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid_h = np.arange(grid_size, dtype=np.float32) - grid_w = np.arange(grid_size, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size, grid_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def b64_2_img(data: str): - image_b64 = base64.b64decode(data) - img = Image.open(BytesIO(image_b64)).convert("RGB") - return img - - -def get_continuous_action(d_acts, c_act_max, c_act_min, n_bins): - c_act_max = c_act_max.to(d_acts.device) - c_act_min = c_act_min.to(d_acts.device) - c_acts = d_acts / (n_bins - 1) * (c_act_max - c_act_min) + c_act_min - return c_acts - - -def alpha2rotm(a): - """Alpha euler angle to rotation matrix.""" - rotm = np.array([[1, 0, 0], [0, np.cos(a), -np.sin(a)], [0, np.sin(a), np.cos(a)]]) - return rotm - - -def beta2rotm(b): - """Beta euler angle to rotation matrix.""" - rotm = np.array([[np.cos(b), 0, np.sin(b)], [0, 1, 0], [-np.sin(b), 0, np.cos(b)]]) - return rotm - - -def gamma2rotm(c): - """Gamma euler angle to rotation matrix.""" - rotm = np.array([[np.cos(c), -np.sin(c), 0], [np.sin(c), np.cos(c), 0], [0, 0, 1]]) - return rotm - - -def euler2rotm(euler_angles): - """Euler angle (ZYX) to rotation matrix.""" - 
alpha = euler_angles[0] - beta = euler_angles[1] - gamma = euler_angles[2] - - rotm_a = alpha2rotm(alpha) - rotm_b = beta2rotm(beta) - rotm_c = gamma2rotm(gamma) - - rotm = rotm_c @ rotm_b @ rotm_a - - return rotm - - -def isRotm(R): - # Checks if a matrix is a valid rotation matrix. - # Forked from Andy Zeng - Rt = np.transpose(R) - shouldBeIdentity = np.dot(Rt, R) - I = np.identity(3, dtype=R.dtype) - n = np.linalg.norm(I - shouldBeIdentity) - return n < 1e-6 - - -def rotm2euler(R): - # Forked from: https://learnopencv.com/rotation-matrix-to-euler-angles/ - # R = Rz * Ry * Rx - assert isRotm(R) - sy = math.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0]) - singular = sy < 1e-6 - - if not singular: - x = math.atan2(R[2, 1], R[2, 2]) - y = math.atan2(-R[2, 0], sy) - z = math.atan2(R[1, 0], R[0, 0]) - else: - x = math.atan2(-R[1, 2], R[1, 1]) - y = math.atan2(-R[2, 0], sy) - z = 0 - - # (-pi , pi] - while x > np.pi: - x -= 2 * np.pi - while x <= -np.pi: - x += 2 * np.pi - while y > np.pi: - y -= 2 * np.pi - while y <= -np.pi: - y += 2 * np.pi - while z > np.pi: - z -= 2 * np.pi - while z <= -np.pi: - z += 2 * np.pi - return np.array([x, y, z]) - - -def get_converted_fp32_paths(deepspeed_ckpt_path): - deepspeed_ckpt_path = deepspeed_ckpt_path.rstrip("/") - ckpt_dir = os.path.dirname(deepspeed_ckpt_path) - ckpt_name = os.path.basename(deepspeed_ckpt_path) - fp32_ckpt_name = f"{ckpt_name}.fp32.pt" - converted_path = os.path.join(ckpt_dir, fp32_ckpt_name) - return converted_path - - -def quat2rotm(quat): - """Quaternion to rotation matrix. - - Args: - quat (4, numpy array): quaternion x, y, z, w - Returns: - rotm (3x3 numpy array): rotation matrix - """ - w = quat[3] - x = quat[0] - y = quat[1] - z = quat[2] - - s = w * w + x * x + y * y + z * z - - rotm = np.array( - [ - [1 - 2 * (y * y + z * z) / s, 2 * (x * y - z * w) / s, 2 * (x * z + y * w) / s], - [2 * (x * y + z * w) / s, 1 - 2 * (x * x + z * z) / s, 2 * (y * z - x * w) / s], - [2 * (x * z - y * w) / s, 2 * (y * z + x * w) / s, 1 - 2 * (x * x + y * y) / s], - ] - ) - - return rotm - - -class Resize_Preprocess: - def __init__(self, size): - """ - Initialize the preprocessing class with the target size. - Args: - size (tuple): The target height and width as a tuple (height, width). - """ - self.size = size - - def __call__(self, video_frames): - """ - Apply the transformation to each frame in the video. - Args: - video_frames (torch.Tensor): A tensor representing a batch of video frames. - Returns: - torch.Tensor: The transformed video frames. 
- """ - # Resize each frame in the video - resized_frames = torch.stack([F.resize(frame, self.size, antialias=True) for frame in video_frames]) - return resized_frames - - -class Preprocess: - def __init__(self, size): - self.size = size - - def __call__(self, clip): - clip = Preprocess.resize_scale(clip, self.size[0], self.size[1], interpolation_mode="bilinear") - return clip - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(size={self.size})" - - @staticmethod - def resize_scale(clip, target_height, target_width, interpolation_mode): - target_ratio = target_height / target_width - H = clip.size(-2) - W = clip.size(-1) - clip_ratio = H / W - if clip_ratio > target_ratio: - scale_ = target_width / W - else: - scale_ = target_height / H - return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) - - -class ToTensorVideo: - """ - Convert tensor data type from uint8 to float, divide value by 255.0 and - permute the dimensions of clip tensor - """ - - def __init__(self): - pass - - def __call__(self, clip): - """ - Args: - clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) - Return: - clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) - """ - return to_tensor(clip) - - def __repr__(self) -> str: - return self.__class__.__name__ - - -def to_tensor(clip): - """ - Convert tensor data type from uint8 to float, divide value by 255.0 and - permute the dimensions of clip tensor - Args: - clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) - Return: - clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) - """ - _is_tensor_video_clip(clip) - if not clip.dtype == torch.uint8: - raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) - # return clip.float().permute(3, 0, 1, 2) / 255.0 - return clip.float() / 255.0 - - -def _is_tensor_video_clip(clip): - if not torch.is_tensor(clip): - raise TypeError("clip should be Tensor. Got %s" % type(clip)) - - if not clip.ndimension() == 4: - raise ValueError("clip should be 4D. Got %dD" % clip.dim()) - - return True diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/functional/loss.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/functional/loss.py deleted file mode 100644 index 22d11006d4718183a5bbc5adc87ddad1eddf32fa..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/functional/loss.py +++ /dev/null @@ -1,135 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Tuple, Union - -import torch - -from cosmos_transfer1.diffusion.functional.batch_ops import batch_mul - - -def create_per_sample_loss_mask( - loss_masking_cfg: dict, - data_batch: dict, - x_shape: Tuple[int], - dtype: torch.dtype, - device: Union[str, torch.device] = "cuda", -): - """ - Creates a per-sample loss mask based on the given configuration and input data batch. - - This function generates a dictionary of loss masks for each specified key in the loss masking configuration. - For keys present in both the configuration and the data batch, the corresponding data batch value is used. - For keys present only in the configuration, a tensor of zeros with the specified shape is created. - Additionally, it computes loss mask weights for each key based on the configuration values and adjusts them - based on the presence of certain keys in the data batch, such as "skip_face" and "object_loss_map". - - Note: - - The original `loss_masking_cfg` and `data_batch` are not modified by this function. - - For image data, it is assumed that the channel is always the first dimension. - - `skip_face` is for face regions that should be skipped during training, the key is provided so that we can generate - diverse human and avoid collapse to a single face given certain prompts. The issue happens for getty projects, - where face distribution in the dataset is high unbalanced that single man face can be shown in more than 100+ images. - - Parameters: - loss_masking_cfg (dict): Configuration for loss masking, specifying which keys to include and their weights. - data_batch (dict): The batch of data containing actual data points and potential mask indicators like "skip_face". - x_shape (tuple): The shape of the input data, used to initialize zero masks for keys not in the data batch. - dtype (torch.dtype): The data type for the tensors in the loss masks. - device (str, optional): The device on which to create the tensors. Defaults to 'cuda'. - - Returns: - dict: A dictionary containing combined loss masks adjusted according to the `loss_masking_cfg` and `data_batch`. - - Raises: - AssertionError: If "skip_face" is not present in `data_batch`. - - Note: `create_combined_loss_mask` is assumed to be a separate function that combines individual loss masks into a - single mask or set of masks based on the given parameters. Its behavior should be documented separately. - """ - loss_mask_data: dict = {} - for key in loss_masking_cfg: - if key not in data_batch: - loss_mask_data[key] = torch.zeros((x_shape[0], 1, x_shape[2], x_shape[3]), device=device) - else: - loss_mask_data[key] = data_batch[key] - - if "skip_face" not in data_batch: - # When skip_face is not there in data_dict, use 0 as default. This will not skip any sample. 
- data_batch["skip_face"] = torch.zeros((x_shape[0],), dtype=dtype, device=device) - - loss_mask_weight: dict = {} - for k, v in loss_masking_cfg.items(): - loss_mask_weight[k] = torch.tensor(v, device=device).expand(data_batch["skip_face"].size()) - - if "human_face_mask" in loss_mask_weight: - loss_mask_weight["human_face_mask"] = (1 - data_batch["skip_face"]) * loss_mask_weight["human_face_mask"] - - if "object_loss_map" in data_batch: - loss_mask_weight["object_loss_map"] = torch.ones(data_batch["object_loss_map"].shape[0], device=device) - - return create_combined_loss_mask(loss_mask_data, x_shape, dtype, device, loss_mask_weight) - - -def create_combined_loss_mask(data, x_shape, dtype, device="cuda", loss_masking=None): - """ - Creates a combined loss mask from multiple input masks. - - This function combines several loss masks into a single mask. In regions where masks overlap, - the highest value is assigned. Non-overlapping regions are assigned a default value of 1. - Regions with a mask value of zero are explicitly zeroed out, which is essential for padded loss calculations. - - Example: - Given the following masks and weights: - mask1: [0, 1, 1, 1, 0, 0], weight: 2 - mask2: [1, 0, 1, 0, 0, 0], weight: 4 - mask3: [0, 1, 0, 0, 0, 0], weight: 0 - The resulting combined loss mask would be: - [4, 0, 4, 2, 1, 1] - - Parameters: - data (dict): Contains the loss masks and their weights. - x_shape (tuple): The shape of the output mask. - dtype: The data type for the output mask. - device: The device on which the output mask will be allocated. - loss_masking: The loss masking weight configuration. - - Returns: - torch.Tensor: The combined loss mask. - """ - - loss_mask = torch.ones(x_shape, dtype=dtype, device=device) - zero_mask = torch.ones(x_shape, dtype=dtype, device=device) - - if loss_masking: - for key in loss_masking: - # Repeat mask along channel's dimension. ndim=4 for images. - repeat_dims = (1, x_shape[1]) + tuple([1] * (data[key].ndim - 2)) - mask_key = torch.tile(data[key], dims=repeat_dims) - weight_key = loss_masking[key] - - # handle zero weight case - is_zero_weight = (weight_key == 0).float()[:, None, None, None] - zero_mask = zero_mask * ( - (1 - is_zero_weight) * torch.ones(x_shape, dtype=dtype, device=device) - + is_zero_weight * (1 - mask_key.bool().float()) - ) - - # calculate weights - no_mask_region = (mask_key.bool() == 0).float() - loss_mask = batch_mul(mask_key, weight_key) + batch_mul(no_mask_region, loss_mask) - - loss_mask_final = loss_mask * zero_mask - return loss_mask_final diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/functional/lr_scheduler.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/functional/lr_scheduler.py deleted file mode 100644 index 007fe1d01566a27fc9a53f4f3966ae34b395c33a..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/functional/lr_scheduler.py +++ /dev/null @@ -1,178 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import numpy as np - -from cosmos_transfer1.utils import distributed, log - - -class TeroPolyScheduler: - def __init__( - self, - total_Mimg: int, - batch_size: int, - ref_Mimg: Optional[int] = None, - ref_batches: float = 70e3 / 1024, - max_lr_ratio: Optional[float] = 1.0, - min_lr_ratio: Optional[float] = None, - rampup_Mimg: float = 0, - rampdown_Mimg: int = 0, - verbosity_interval: int = 0, - formula: str = "poly", - poly_exp: float = 0.5, - ): - self.total_Mimg = total_Mimg - self.batch_size = batch_size * distributed.get_world_size() - self.ref_Mimg = ref_Mimg or ref_batches * batch_size / 1e6 - self.ref_batches = ref_batches - self.max_lr_ratio = max_lr_ratio - self.min_lr_ratio = min_lr_ratio - self.rampup_Mimg = rampup_Mimg - self.rampdown_Mimg = rampdown_Mimg - self.verbosity_interval = verbosity_interval - self.formula = formula - self.poly_exp = poly_exp - - self._model = None - - @property - def model(self): - return self._model - - @model.setter - def model(self, model): - self._model = model - - def schedule(self, n, **kwargs): - cur_Mimg = getattr(self.model, "sample_counter", 0) / 1e6 - - if self.formula == "constant": - lr = 1.0 - elif self.formula == "poly": - lr = max(cur_Mimg / self.ref_Mimg, 1e-8) ** -self.poly_exp - else: - raise ValueError(f'Invalid learning rate formula "{self.formula}"') - - if self.max_lr_ratio is not None: - lr = min(lr, self.max_lr_ratio) - if self.min_lr_ratio is not None: - lr = max(lr, self.min_lr_ratio) - - if self.rampup_Mimg > 0 and cur_Mimg < self.rampup_Mimg: - lr *= cur_Mimg / self.rampup_Mimg - if self.rampdown_Mimg > 0 and cur_Mimg > self.total_Mimg - self.rampdown_Mimg: - lr *= (self.total_Mimg - cur_Mimg) / self.rampdown_Mimg - - return lr - - def __call__(self, n, **kwargs): - return self.schedule(n, **kwargs) - - -class LambdaWarmUpCosineScheduler: - """ - A learning rate scheduler that combines warm-up with a cosine decay schedule for multiple cycles. - It supports different configurations for each cycle, including the number of warm-up steps, minimum - and maximum scaling factors for the learning rate. - - The scheduler is intended to be used with a base learning rate of 1.0, where the actual learning - rate at any step is the base learning rate multiplied by the scaling factor computed by the scheduler. - - Parameters: - warm_up_steps (list[int]): List of integers where each element represents the number of warm-up - steps for the corresponding cycle. - f_min (list[float]): List of the minimum scaling factors for each cycle after warm-up. - f_max (list[float]): List of the maximum scaling factors at the start and end of each cosine cycle. - f_start (list[float]): List of starting scaling factors for each warm-up phase. - cycle_lengths (list[int]): List of the total lengths of each cycle, including warm-up steps. - verbosity_interval (int, optional): Interval of training steps at which to print current step and - scaling factor information. Set to 0 by default to disable verbosity. 
- - Examples: - >>> scheduler = LambdaWarmUpCosineScheduler2( - warm_up_steps=[10, 10], - f_min=[0.1, 0.1], - f_max=[1.0, 1.0], - f_start=[0.01, 0.01], - cycle_lengths=[50, 50], - verbosity_interval=10) - >>> for step in range(100): - >>> lr_multiplier = scheduler(step) - >>> print(f"Step {step}: LR Multiplier = {lr_multiplier}") - """ - - def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): - assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) - self.lr_warm_up_steps = warm_up_steps - self.f_start = f_start - self.f_min = f_min - self.f_max = f_max - self.cycle_lengths = cycle_lengths - self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) - self.last_f = 0.0 - self.verbosity_interval = verbosity_interval - - def find_in_interval(self, n): - interval = 0 - for cl in self.cum_cycles[1:]: - if n <= cl: - return interval - interval += 1 - - def schedule(self, n, **kwargs): - cycle = self.find_in_interval(n) - n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - log.info(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") - if n < self.lr_warm_up_steps[cycle]: - f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] - self.last_f = f - return f - else: - t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) - t = min(t, 1.0) - f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) - self.last_f = f - return f - - def __call__(self, n, **kwargs): - return self.schedule(n, **kwargs) - - -class LambdaLinearScheduler(LambdaWarmUpCosineScheduler): - """ - Linear instead of cosine decay for the main part of the cycle. - """ - - def schedule(self, n, **kwargs): - cycle = self.find_in_interval(n) - n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - log.info(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") - - if n < self.lr_warm_up_steps[cycle]: - f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] - self.last_f = f - return f - else: - f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / ( - self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] - ) - self.last_f = f - return f diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/extend_model.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/extend_model.py deleted file mode 100644 index 4784c4e891c73067cd1a2efa79ae1f4df0b575e8..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/extend_model.py +++ /dev/null @@ -1,576 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from statistics import NormalDist -from typing import Callable, Dict, Optional, Tuple, Union - -import numpy as np -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import Tensor - -from cosmos_transfer1.diffusion.conditioner import DataType, VideoExtendCondition -from cosmos_transfer1.diffusion.config.base.conditioner import VideoCondBoolConfig -from cosmos_transfer1.diffusion.functional.batch_ops import batch_mul -from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp -from cosmos_transfer1.diffusion.training.models.model import DiffusionModel as BaseModel -from cosmos_transfer1.diffusion.training.models.model import _broadcast, broadcast_condition -from cosmos_transfer1.diffusion.training.models.model_image import diffusion_fsdp_class_decorator -from cosmos_transfer1.utils import log, misc - - -@dataclass -class VideoDenoisePrediction: - x0: torch.Tensor # clean data prediction - eps: Optional[torch.Tensor] = None # noise prediction - logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used a confidence / uncertainty - net_in: Optional[torch.Tensor] = None # input to the network - net_x0_pred: Optional[torch.Tensor] = None # prediction of x0 from the network - xt: Optional[torch.Tensor] = None # input to the network, before muliply with c_in - x0_pred_replaced: Optional[torch.Tensor] = None # x0 prediction with condition region replaced by gt_latent - - -def normalize_condition_latent(condition_latent): - """Normalize the condition latent tensor to have zero mean and unit variance - Args: - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W - """ - condition_latent_2D = rearrange(condition_latent, "b c t h w -> b c t (h w)") - mean = condition_latent_2D.mean(dim=-1) - std = condition_latent_2D.std(dim=-1) - # bct -> bct11 - mean = mean.unsqueeze(-1).unsqueeze(-1) - std = std.unsqueeze(-1).unsqueeze(-1) - condition_latent = (condition_latent - mean) / std - return condition_latent - - -class ExtendDiffusionModel(BaseModel): - def __init__(self, config): - super().__init__(config) - self.is_extend_model = True - - def get_data_and_condition( - self, data_batch: dict[str, Tensor], num_condition_t: Union[int, None] = None - ) -> Tuple[Tensor, Tensor, VideoExtendCondition]: - raw_state, latent_state, condition = super().get_data_and_condition(data_batch) - if condition.data_type == DataType.VIDEO: - if self.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i: - latent_state = self.sample_tokens_start_from_p_or_i(latent_state) - condition = self.add_condition_video_indicator_and_video_input_mask( - latent_state, condition, num_condition_t=num_condition_t - ) - if self.config.conditioner.video_cond_bool.add_pose_condition: - condition = self.add_condition_pose(data_batch, condition) - log.debug(f"condition.data_type {condition.data_type}") - return raw_state, latent_state, condition - - def draw_augment_sigma_and_epsilon( - self, size: int, condition: VideoExtendCondition, p_mean: float, p_std: float, multiplier: float - ) -> Tensor: - is_video_batch = condition.data_type == DataType.VIDEO - del condition - batch_size = size[0] - epsilon = torch.randn(size, **self.tensor_kwargs) - - gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) - cdf_vals = np.random.uniform(size=(batch_size)) - samples_interval_gaussian = 
[gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] - - log_sigma = torch.tensor(samples_interval_gaussian, device="cuda") - sigma_B = torch.exp(log_sigma).to(**self.tensor_kwargs) - - sigma_B = _broadcast(sigma_B * multiplier, to_tp=True, to_cp=is_video_batch) - epsilon = _broadcast(epsilon, to_tp=True, to_cp=is_video_batch) - return sigma_B, epsilon - - def augment_conditional_latent_frames( - self, - condition: VideoExtendCondition, - cfg_video_cond_bool: VideoCondBoolConfig, - gt_latent: Tensor, - condition_video_augment_sigma_in_inference: float = 0.001, - sigma: Tensor = None, - seed_inference: int = 1, - ) -> Union[VideoExtendCondition, Tensor]: - """This function is used to augment the condition input with noise - Args: - condition (VideoExtendCondition): condition object - condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor. - condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network. - cfg_video_cond_bool (VideoCondBoolConfig): video condition bool config - gt_latent (Tensor): ground truth latent tensor in shape B,C,T,H,W - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - sigma (Tensor): noise level for the generation region - Returns: - VideoExtendCondition: updated condition object - condition_video_augment_sigma: sigma for the condition region, feed to the network - augment_latent (Tensor): augmented latent tensor in shape B,C,T,H,W - - """ - - if cfg_video_cond_bool.apply_corruption_to_condition_region == "noise_with_sigma": - # Training only, sample sigma for the condition region - augment_sigma, _ = self.draw_augment_sigma_and_epsilon( - gt_latent.shape, - condition, - cfg_video_cond_bool.augment_sigma_sample_p_mean, - cfg_video_cond_bool.augment_sigma_sample_p_std, - cfg_video_cond_bool.augment_sigma_sample_multiplier, - ) - noise = torch.randn(*gt_latent.shape, **self.tensor_kwargs) - - elif cfg_video_cond_bool.apply_corruption_to_condition_region == "noise_with_sigma_fixed": - # Inference only, use fixed sigma for the condition region - log.debug( - f"condition_video_augment_sigma_in_inference={condition_video_augment_sigma_in_inference}, sigma={sigma.flatten()[0]}" - ) - assert ( - condition_video_augment_sigma_in_inference is not None - ), "condition_video_augment_sigma_in_inference should be provided" - augment_sigma = condition_video_augment_sigma_in_inference - - if augment_sigma >= sigma.flatten()[0]: - # This is a inference trick! If the sampling sigma is smaller than the augment sigma, we will start denoising the condition region together. - # This is achieved by setting all region as `generation`, i.e. 
value=0 - log.debug("augment_sigma larger than sigma or other frame, remove condition") - condition.condition_video_indicator = condition.condition_video_indicator * 0 - - augment_sigma = torch.tensor([augment_sigma], **self.tensor_kwargs) - - # Inference, use fixed seed - noise = misc.arch_invariant_rand( - gt_latent.shape, - torch.float32, - self.tensor_kwargs["device"], - seed_inference, - ) - else: - raise ValueError(f"does not support {cfg_video_cond_bool.apply_corruption_to_condition_region}") - - # Now apply the augment_sigma to the gt_latent - - augment_latent = gt_latent + noise * augment_sigma.view(-1, 1, 1, 1, 1) - _, _, c_in_augment, c_noise_augment = self.scaling(sigma=augment_sigma) - - if cfg_video_cond_bool.condition_on_augment_sigma: # model takes augment_sigma as input - if condition.condition_video_indicator.sum() > 0: # has condition frames - condition.condition_video_augment_sigma = c_noise_augment - else: # no condition frames - condition.condition_video_augment_sigma = torch.zeros_like(c_noise_augment) - - # Multiply the whole latent with c_in_augment - augment_latent_cin = batch_mul(augment_latent, c_in_augment) - - # Since the whole latent will multiply with c_in later, we devide the value to cancel the effect - _, _, c_in, _ = self.scaling(sigma=sigma) - augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in) - - return condition, augment_latent_cin - - def drop_out_condition_region( - self, augment_latent: Tensor, noise_x: Tensor, cfg_video_cond_bool: VideoCondBoolConfig - ) -> Tensor: - """Use for CFG on input frames, we drop out the conditional region - There are two option: - 1. when we dropout, we set the region to be zero - 2. when we dropout, we set the region to be noise_x - """ - # Unconditional case, use for cfg - if cfg_video_cond_bool.cfg_unconditional_type == "zero_condition_region_condition_mask": - # Set the condition location input to be zero - augment_latent_drop = torch.zeros_like(augment_latent) - elif cfg_video_cond_bool.cfg_unconditional_type == "noise_x_condition_region": - # Set the condition location input to be noise_x, i.e., same as base model training - augment_latent_drop = noise_x - else: - raise NotImplementedError( - f"cfg_unconditional_type {cfg_video_cond_bool.cfg_unconditional_type} not implemented" - ) - return augment_latent_drop - - def denoise( - self, - noise_x: Tensor, - sigma: Tensor, - condition: VideoExtendCondition, - condition_video_augment_sigma_in_inference: float = 0.001, - seed_inference: int = 1, - ) -> VideoDenoisePrediction: - """ - Denoise the noisy input tensor. - - Args: - noise_x (Tensor): Noisy input tensor. - sigma (Tensor): Noise level. - condition (VideoExtendCondition): Condition for denoising. - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - - Returns: - Tensor: Denoised output tensor. 
- """ - if condition.data_type == DataType.IMAGE: - pred = super().denoise(noise_x, sigma, condition) - log.debug(f"hit image denoise, noise_x shape {noise_x.shape}, sigma shape {sigma.shape}", rank0_only=False) - return VideoDenoisePrediction( - x0=pred.x0, - eps=pred.eps, - logvar=pred.logvar, - xt=noise_x, - ) - else: - assert ( - condition.gt_latent is not None - ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}" - gt_latent = condition.gt_latent - cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool - - condition_latent = gt_latent - - if cfg_video_cond_bool.normalize_condition_latent: - condition_latent = normalize_condition_latent(condition_latent) - - # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed - condition, augment_latent = self.augment_conditional_latent_frames( - condition, - cfg_video_cond_bool, - condition_latent, - condition_video_augment_sigma_in_inference, - sigma, - seed_inference=seed_inference, - ) - condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] - if parallel_state.get_context_parallel_world_size() > 1: - cp_group = parallel_state.get_context_parallel_group() - condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) - augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group) - gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group) - - if not condition.video_cond_bool: - # Unconditional case, drop out the condition region - augment_latent = self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool) - - # Compose the model input with condition region (augment_latent) and generation region (noise_x) - new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x - # Call the abse model - denoise_pred = super().denoise(new_noise_xt, sigma, condition) - - x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 - if cfg_video_cond_bool.compute_loss_for_condition_region: - # We also denoise the conditional region - x0_pred = denoise_pred.x0 - else: - x0_pred = x0_pred_replaced - - return VideoDenoisePrediction( - x0=x0_pred, - eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), - logvar=denoise_pred.logvar, - net_in=batch_mul(1.0 / torch.sqrt(self.sigma_data**2 + sigma**2), new_noise_xt), - net_x0_pred=denoise_pred.x0, - xt=new_noise_xt, - x0_pred_replaced=x0_pred_replaced, - ) - - def generate_samples_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - seed: int = 1, - state_shape: Tuple | None = None, - n_sample: int | None = None, - is_negative_prompt: bool = False, - num_steps: int = 35, - condition_latent: Union[torch.Tensor, None] = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - add_input_frames_guidance: bool = False, - return_noise: bool = False, - ) -> Tensor | Tuple[Tensor, Tensor]: - """ - Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. - Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. 
- If this feature is stablized, we could consider to move this function to the base model. - - Args: - condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. - num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half - - add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames - return_noise (bool): return the initial noise or not, used for ODE pairs generation - """ - self._normalize_video_databatch_inplace(data_batch) - self._augment_image_dim_inplace(data_batch) - is_image_batch = self.is_image_batch(data_batch) - if is_image_batch: - log.debug("image batch, call base model generate_samples_from_batch") - return super().generate_samples_from_batch( - data_batch, - guidance=guidance, - seed=seed, - state_shape=state_shape, - n_sample=n_sample, - is_negative_prompt=is_negative_prompt, - num_steps=num_steps, - ) - if n_sample is None: - input_key = self.input_image_key if is_image_batch else self.input_data_key - n_sample = data_batch[input_key].shape[0] - if state_shape is None: - if is_image_batch: - state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W - else: - log.debug(f"Default Video state shape is used. {self.state_shape}") - state_shape = self.state_shape - - assert condition_latent is not None, "condition_latent should be provided" - - x0_fn = self.get_x0_fn_from_batch_with_condition_latent( - data_batch, - guidance, - is_negative_prompt=is_negative_prompt, - condition_latent=condition_latent, - num_condition_t=num_condition_t, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - add_input_frames_guidance=add_input_frames_guidance, - seed_inference=seed, # Use for noise of augment sigma - ) - - x_sigma_max = ( - misc.arch_invariant_rand( - (n_sample,) + tuple(state_shape), torch.float32, self.tensor_kwargs["device"], seed - ) - * self.sde.sigma_max - ) - if self.net.is_context_parallel_enabled: - x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - - samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) - if self.net.is_context_parallel_enabled: - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) - - if return_noise: - if self.net.is_context_parallel_enabled: - x_sigma_max = cat_outputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - return samples, x_sigma_max / self.sde.sigma_max - - return samples - - def get_x0_fn_from_batch_with_condition_latent( - self, - data_batch: Dict, - guidance: float = 1.5, - is_negative_prompt: bool = False, - condition_latent: torch.Tensor = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - add_input_frames_guidance: bool = False, - seed_inference: int = 1, - ) -> Callable: - """ - Generates a callable function `x0_fn` based on the provided data batch and guidance factor. - Different from the base model, this function support condition latent as input, it will add the condition information into the condition and uncondition object. - - This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. 
- - Args: - - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` - - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. - - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. - - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" - - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - - add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames - - Returns: - - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin - - The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. - """ - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - - condition.video_cond_bool = True - condition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, condition, num_condition_t - ) - if self.config.conditioner.video_cond_bool.add_pose_condition: - condition = self.add_condition_pose(data_batch, condition) - - uncondition.video_cond_bool = False if add_input_frames_guidance else True - uncondition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, uncondition, num_condition_t - ) - if self.config.conditioner.video_cond_bool.add_pose_condition: - uncondition = self.add_condition_pose(data_batch, uncondition) - - to_cp = self.net.is_context_parallel_enabled - # For inference, check if parallel_state is initialized - if parallel_state.is_initialized(): - condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) - uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) - else: - assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." - - def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - cond_x0 = self.denoise( - noise_x, - sigma, - condition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed_inference=seed_inference, - ).x0_pred_replaced - uncond_x0 = self.denoise( - noise_x, - sigma, - uncondition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed_inference=seed_inference, - ).x0_pred_replaced - return cond_x0 + guidance * (cond_x0 - uncond_x0) - - return x0_fn - - def add_condition_video_indicator_and_video_input_mask( - self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None - ) -> VideoExtendCondition: - """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. - condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. - condition_video_input_mask will be concat with the input for the network. 
- Args: - latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W - condition (VideoExtendCondition): condition object - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" - Returns: - VideoExtendCondition: updated condition object - """ - T = latent_state.shape[2] - latent_dtype = latent_state.dtype - condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( - latent_dtype - ) # 1 for condition region - if self.config.conditioner.video_cond_bool.condition_location == "first_n": - # Only in inference to decide the condition region - assert num_condition_t is not None, "num_condition_t should be provided" - assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" - log.info( - f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" - ) - condition_video_indicator[:, :, :num_condition_t] += 1.0 - elif self.config.conditioner.video_cond_bool.condition_location == "first_random_n": - # Only in training - num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max - assert ( - num_condition_t_max <= T - ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" - assert num_condition_t_max >= self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min - num_condition_t = torch.randint( - self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_min, - num_condition_t_max + 1, - (1,), - ).item() - condition_video_indicator[:, :, :num_condition_t] += 1.0 - - elif self.config.conditioner.video_cond_bool.condition_location == "random": - # Only in training - condition_rate = self.config.conditioner.video_cond_bool.random_conditon_rate - flag = torch.ones(1, 1, T, 1, 1, device=latent_state.device).type(latent_dtype) * condition_rate - condition_video_indicator = torch.bernoulli(flag).type(latent_dtype).to(latent_state.device) - else: - raise NotImplementedError( - f"condition_location {self.config.conditioner.video_cond_bool.condition_location} not implemented; training={self.training}" - ) - condition.gt_latent = latent_state - condition.condition_video_indicator = condition_video_indicator - - B, C, T, H, W = latent_state.shape - # Create additional input_mask channel, this will be concatenated to the input of the network - # See design doc section (Implementation detail A.1 and A.2) for visualization - ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) - zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) - assert condition.video_cond_bool is not None, "video_cond_bool should be set" - - # The input mask indicate whether the input is conditional region or not - if condition.video_cond_bool: # Condition one given video frames - condition.condition_video_input_mask = ( - condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding - ) - else: # Unconditional case, use for cfg - condition.condition_video_input_mask = zeros_padding - - to_cp = self.net.is_context_parallel_enabled - # For inference, check if parallel_state is initialized - if parallel_state.is_initialized(): - condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) - else: - assert not to_cp, "parallel_state is not initialized, context parallel should 
be turned off." - - return condition - - def add_condition_pose(self, data_batch: Dict, condition: VideoExtendCondition) -> VideoExtendCondition: - """Add pose condition to the condition object. For camera control model - Args: - data_batch (Dict): data batch, with key "plucker_embeddings", in shape B,T,C,H,W - latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W - condition (VideoExtendCondition): condition object - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" - Returns: - VideoExtendCondition: updated condition object - """ - assert ( - "plucker_embeddings" in data_batch or "plucker_embeddings_downsample" in data_batch.keys() - ), f"plucker_embeddings should be in data_batch. only find {data_batch.keys()}" - plucker_embeddings = ( - data_batch["plucker_embeddings"] - if "plucker_embeddings_downsample" not in data_batch.keys() - else data_batch["plucker_embeddings_downsample"] - ) - condition.condition_video_pose = rearrange(plucker_embeddings, "b t c h w -> b c t h w").contiguous() - to_cp = self.net.is_context_parallel_enabled - # For inference, check if parallel_state is initialized - if parallel_state.is_initialized(): - condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) - else: - assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." - - return condition - - def sample_tokens_start_from_p_or_i(self, latent_state: torch.Tensor) -> torch.Tensor: - """Sample the PPP... from the IPPP... sequence, only for video sequence - Args: - latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W - Returns: - torch.Tensor: sampled PPP tensor in shape B,C,T,H,W - """ - B, C, T, H, W = latent_state.shape - latent_dtype = latent_state.dtype - T_target = self.state_shape[1] - latent_state_sample = torch.zeros((B, C, T_target, H, W), dtype=latent_dtype, device=latent_state.device) - t_start = torch.randint(0, T - T_target + 1, (1,)) - # broadcast to other device - latent_state_sample = latent_state[:, :, t_start : t_start + T_target].contiguous() - if parallel_state.is_initialized(): - latent_state_sample = _broadcast(latent_state_sample, to_tp=True, to_cp=True) - - return latent_state_sample - - -@diffusion_fsdp_class_decorator -class FSDPExtendDiffusionModel(ExtendDiffusionModel): - pass diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/extend_model_multiview.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/extend_model_multiview.py deleted file mode 100644 index 87420db9c7d9b7723be3d80d1be5d9d3f180ee99..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/extend_model_multiview.py +++ /dev/null @@ -1,582 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import copy -from typing import Callable, Dict, Tuple, Union - -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import Tensor - -from cosmos_transfer1.diffusion.conditioner import DataType, VideoExtendCondition, ViewConditionedVideoExtendCondition -from cosmos_transfer1.diffusion.config.base.conditioner import VideoCondBoolConfig -from cosmos_transfer1.diffusion.functional.batch_ops import batch_mul -from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp -from cosmos_transfer1.diffusion.training.models.extend_model import ( - ExtendDiffusionModel, - VideoDenoisePrediction, - normalize_condition_latent, -) -from cosmos_transfer1.diffusion.training.models.model import DiffusionModel, broadcast_condition -from cosmos_transfer1.diffusion.training.models.model_image import CosmosCondition, diffusion_fsdp_class_decorator -from cosmos_transfer1.diffusion.training.models.model_multiview import deepcopy_no_copy_model -from cosmos_transfer1.utils import log - - -class MultiviewExtendDiffusionModel(ExtendDiffusionModel): - def __init__(self, config): - super().__init__(config) - self.n_views = config.n_views - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - state = rearrange(state, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - encoded_state = self.vae.encode(state) - encoded_state = rearrange(encoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) * self.sigma_data - return encoded_state - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - decoded_state = self.vae.decode(latent / self.sigma_data) - decoded_state = rearrange(decoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - return decoded_state - - def compute_loss_with_epsilon_and_sigma( - self, - data_batch: dict[str, torch.Tensor], - x0_from_data_batch: torch.Tensor, - x0: torch.Tensor, - condition: CosmosCondition, - epsilon: torch.Tensor, - sigma: torch.Tensor, - ): - if self.is_image_batch(data_batch): - # Turn off CP - self.net.disable_context_parallel() - else: - if parallel_state.is_initialized(): - if parallel_state.get_context_parallel_world_size() > 1: - # Turn on CP - cp_group = parallel_state.get_context_parallel_group() - self.net.enable_context_parallel(cp_group) - log.debug("[CP] Split x0 and epsilon") - - x0 = rearrange(x0, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - epsilon = rearrange(epsilon, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - - x0 = split_inputs_cp(x=x0, seq_dim=2, cp_group=self.net.cp_group) - epsilon = split_inputs_cp(x=epsilon, seq_dim=2, cp_group=self.net.cp_group) - - x0 = rearrange(x0, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - epsilon = rearrange(epsilon, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - output_batch, kendall_loss, pred_mse, edm_loss = super( - DiffusionModel, self - ).compute_loss_with_epsilon_and_sigma(data_batch, x0_from_data_batch, x0, condition, epsilon, sigma) - if not self.is_image_batch(data_batch): - if self.loss_reduce == "sum" and parallel_state.get_context_parallel_world_size() > 1: - kendall_loss *= parallel_state.get_context_parallel_world_size() - - return output_batch, kendall_loss, pred_mse, edm_loss - - def denoise( - self, - noise_x: Tensor, - sigma: Tensor, - condition: VideoExtendCondition, - condition_video_augment_sigma_in_inference: float = 0.001, - seed_inference: int = 1, - ) -> 
VideoDenoisePrediction: - """ - Denoise the noisy input tensor. - - Args: - noise_x (Tensor): Noisy input tensor. - sigma (Tensor): Noise level. - condition (VideoExtendCondition): Condition for denoising. - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - - Returns: - Tensor: Denoised output tensor. - """ - if condition.data_type == DataType.IMAGE: - pred = super(DiffusionModel, self).denoise(noise_x, sigma, condition) - log.debug(f"hit image denoise, noise_x shape {noise_x.shape}, sigma shape {sigma.shape}", rank0_only=False) - return VideoDenoisePrediction( - x0=pred.x0, - eps=pred.eps, - logvar=pred.logvar, - xt=noise_x, - ) - else: - assert ( - condition.gt_latent is not None - ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}" - condition = deepcopy_no_copy_model(condition) - gt_latent = condition.gt_latent - cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool - - condition_latent = gt_latent - - if cfg_video_cond_bool.normalize_condition_latent: - condition_latent = normalize_condition_latent(condition_latent) - - # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed - condition, augment_latent = self.augment_conditional_latent_frames( - condition, - cfg_video_cond_bool, - condition_latent, - condition_video_augment_sigma_in_inference, - sigma, - seed_inference, - ) - condition_video_indicator = condition.condition_video_indicator # [B, 1, T, 1, 1] - if parallel_state.get_context_parallel_world_size() > 1: - cp_group = parallel_state.get_context_parallel_group() - - condition_video_indicator = rearrange( - condition_video_indicator, "B C (V T) H W -> (B V) C T H W", V=self.n_views - ) - augment_latent = rearrange(augment_latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - gt_latent = rearrange(gt_latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - if getattr(condition, "view_indices_B_T", None) is not None: - view_indices_B_V_T = rearrange(condition.view_indices_B_T, "B (V T) -> (B V) T", V=self.n_views) - view_indices_B_V_T = split_inputs_cp(view_indices_B_V_T, seq_dim=1, cp_group=cp_group) - condition.view_indices_B_T = rearrange(view_indices_B_V_T, "(B V) T -> B (V T)", V=self.n_views) - condition_video_indicator = split_inputs_cp(condition_video_indicator, seq_dim=2, cp_group=cp_group) - augment_latent = split_inputs_cp(augment_latent, seq_dim=2, cp_group=cp_group) - gt_latent = split_inputs_cp(gt_latent, seq_dim=2, cp_group=cp_group) - - condition_video_indicator = rearrange( - condition_video_indicator, "(B V) C T H W -> B C (V T) H W", V=self.n_views - ) - augment_latent = rearrange(augment_latent, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - gt_latent = rearrange(gt_latent, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - - if not condition.video_cond_bool: - # Unconditional case, drop out the condition region - augment_latent = self.drop_out_condition_region(augment_latent, noise_x, cfg_video_cond_bool) - # Compose the model input with condition region (augment_latent) and generation region (noise_x) - new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x - # Call the abse model - - denoise_pred = super(DiffusionModel, self).denoise(new_noise_xt, sigma, condition) - - x0_pred_replaced = 
condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0 - if cfg_video_cond_bool.compute_loss_for_condition_region: - # We also denoise the conditional region - x0_pred = denoise_pred.x0 - else: - x0_pred = x0_pred_replaced - - return VideoDenoisePrediction( - x0=x0_pred, - eps=batch_mul(noise_x - x0_pred, 1.0 / sigma), - logvar=denoise_pred.logvar, - net_in=batch_mul(1.0 / torch.sqrt(self.sigma_data**2 + sigma**2), new_noise_xt), - net_x0_pred=denoise_pred.x0, - xt=new_noise_xt, - x0_pred_replaced=x0_pred_replaced, - ) - - def add_condition_video_indicator_and_video_input_mask( - self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None - ) -> VideoExtendCondition: - """Add condition_video_indicator and condition_video_input_mask to the condition object for video conditioning. - condition_video_indicator is a binary tensor indicating the condition region in the latent state. 1x1xTx1x1 tensor. - condition_video_input_mask will be concat with the input for the network. - Args: - latent_state (torch.Tensor): latent state tensor in shape B,C,T,H,W - condition (VideoExtendCondition): condition object - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" - Returns: - VideoExtendCondition: updated condition object - """ - T = latent_state.shape[2] - latent_dtype = latent_state.dtype - condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type( - latent_dtype - ) # 1 for condition region - - condition_video_indicator = rearrange( - condition_video_indicator, "B C (V T) H W -> (B V) C T H W", V=self.n_views - ) - if self.config.conditioner.video_cond_bool.condition_location == "first_n": - # Only in inference to decide the condition region - assert num_condition_t is not None, "num_condition_t should be provided" - assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" - log.info( - f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" - ) - - condition_video_indicator[:, :, :num_condition_t] += 1.0 - - elif self.config.conditioner.video_cond_bool.condition_location == "first_random_n": - # Only in training - num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max - assert ( - num_condition_t_max <= T - ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" - num_condition_t = torch.randint(0, num_condition_t_max + 1, (1,)).item() - condition_video_indicator[:, :, :num_condition_t] += 1.0 - - elif self.config.conditioner.video_cond_bool.condition_location == "first_cam": - # condition on first cam - condition_video_indicator = rearrange( - condition_video_indicator, "(B V) C T H W -> B V C T H W", V=self.n_views - ) - condition_video_indicator[:, 0, :, :, :, :] += 1.0 - condition_video_indicator = rearrange( - condition_video_indicator, "B V C T H W -> (B V) C T H W", V=self.n_views - ) - elif self.config.conditioner.video_cond_bool.condition_location == "any_cam": - # condition on any n camera - n_cond_view = torch.randint( - self.config.conditioner.video_cond_bool.n_cond_view_min, - self.config.conditioner.video_cond_bool.n_cond_view_max + 1, - (1,), - ).item() - vids = torch.randperm(self.n_views) - cond_vids = vids[:n_cond_view] - - condition_video_indicator = rearrange( - 
condition_video_indicator, "(B V) C T H W -> B V C T H W", V=self.n_views - ) - - for vidx in cond_vids: - condition_video_indicator[:, vidx.item(), :, :, :, :] += 1.0 - condition_video_indicator = torch.clamp(condition_video_indicator, 0, 1) - condition_video_indicator = rearrange( - condition_video_indicator, "B V C T H W -> (B V) C T H W", V=self.n_views - ) - elif self.config.conditioner.video_cond_bool.condition_location == "any_cam_and_random_n": - # condition on any n camera - n_cond_view = torch.randint( - self.config.conditioner.video_cond_bool.n_cond_view_min, - self.config.conditioner.video_cond_bool.n_cond_view_max + 1, - (1,), - ).item() - vids = torch.randperm(self.n_views) - cond_vids = vids[:n_cond_view] - - condition_video_indicator = rearrange( - condition_video_indicator, "(B V) C T H W -> B V C T H W", V=self.n_views - ) - - for vidx in cond_vids: - condition_video_indicator[:, vidx.item(), :, :, :, :] += 1.0 - # condition_video_indicator = torch.clamp(condition_video_indicator, 0, 1) - condition_video_indicator = rearrange( - condition_video_indicator, "B V C T H W -> (B V) C T H W", V=self.n_views - ) - - num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max - assert ( - num_condition_t_max <= T - ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" - num_condition_t = torch.randint(0, num_condition_t_max + 1, (1,)).item() - condition_video_indicator[:, :, :num_condition_t] += 1.0 - condition_video_indicator = condition_video_indicator.clamp(max=1.0) - elif self.config.conditioner.video_cond_bool.condition_location.startswith("fixed_cam_and_first_n"): - # condition on a list of cameras specified through the string - cond_vids = [int(c) for c in self.config.conditioner.video_cond_bool.condition_location.split("_")[5:]] - - condition_video_indicator = rearrange( - condition_video_indicator, "(B V) C T H W -> B V C T H W", V=self.n_views - ) - - for vidx in cond_vids: - condition_video_indicator[:, vidx, :, :, :, :] += 1.0 - condition_video_indicator = torch.clamp(condition_video_indicator, 0, 1) - condition_video_indicator = rearrange( - condition_video_indicator, "B V C T H W -> (B V) C T H W", V=self.n_views - ) - log.info( - f"condition_location fixed_cam_and_first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" - ) - condition_video_indicator[:, :, :num_condition_t] += 1.0 - condition_video_indicator = condition_video_indicator.clamp(max=1.0) - - elif self.config.conditioner.video_cond_bool.condition_location.startswith("fixed_cam"): - # condition on a list of cameras specified through the string - cond_vids = [int(c) for c in self.config.conditioner.video_cond_bool.condition_location.split("_")[2:]] - - condition_video_indicator = rearrange( - condition_video_indicator, "(B V) C T H W -> B V C T H W", V=self.n_views - ) - - for vidx in cond_vids: - condition_video_indicator[:, vidx, :, :, :, :] += 1.0 - condition_video_indicator = torch.clamp(condition_video_indicator, 0, 1) - condition_video_indicator = rearrange( - condition_video_indicator, "B V C T H W -> (B V) C T H W", V=self.n_views - ) - elif self.config.conditioner.video_cond_bool.condition_location == "first_cam_and_random_n": - # condition on first cam - condition_video_indicator = rearrange( - condition_video_indicator, "(B V) C T H W -> B V C T H W", V=self.n_views - ) - condition_video_indicator[:, 0, :, :, :, :] += 1.0 - condition_video_indicator = rearrange( - condition_video_indicator, 
"B V C T H W -> (B V) C T H W", V=self.n_views - ) - # and condition on first few cams - num_condition_t_max = self.config.conditioner.video_cond_bool.first_random_n_num_condition_t_max - assert ( - num_condition_t_max <= T - ), f"num_condition_t_max should be less than T, get {num_condition_t_max}, {T}" - num_condition_t = torch.randint(0, num_condition_t_max + 1, (1,)).item() - condition_video_indicator[:, :, :num_condition_t] += 1.0 - condition_video_indicator = condition_video_indicator.clamp(max=1.0) - elif self.config.conditioner.video_cond_bool.condition_location == "first_cam_and_first_n": - # condition on first cam - condition_video_indicator = rearrange( - condition_video_indicator, "(B V) C T H W -> B V C T H W", V=self.n_views - ) - condition_video_indicator[:, 0, :, :, :, :] += 1.0 - condition_video_indicator = rearrange( - condition_video_indicator, "B V C T H W -> (B V) C T H W", V=self.n_views - ) - assert num_condition_t is not None, "num_condition_t should be provided" - assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}" - log.info( - f"condition_location first_cam_and_first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}" - ) - condition_video_indicator[:, :, :num_condition_t] += 1.0 - condition_video_indicator = condition_video_indicator.clamp(max=1.0) - else: - raise NotImplementedError( - f"condition_location {self.config.conditioner.video_cond_bool.condition_location} not implemented; training={self.training}" - ) - - condition_video_indicator = rearrange( - condition_video_indicator, "(B V) C T H W -> B C (V T) H W", V=self.n_views - ) - - condition.gt_latent = latent_state - condition.condition_video_indicator = condition_video_indicator - - B, C, T, H, W = latent_state.shape - # Create additional input_mask channel, this will be concatenated to the input of the network - # See design doc section (Implementation detail A.1 and A.2) for visualization - ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) - zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device) - assert condition.video_cond_bool is not None, "video_cond_bool should be set" - - # The input mask indicate whether the input is conditional region or not - if condition.video_cond_bool: # Condition one given video frames - condition.condition_video_input_mask = ( - condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding - ) - else: # Unconditional case, use for cfg - condition.condition_video_input_mask = zeros_padding - - to_cp = self.net.is_context_parallel_enabled - # For inference, check if parallel_state is initialized - if parallel_state.is_initialized(): - condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) - else: - assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." - - return condition - - def get_x0_fn_from_batch_with_condition_latent( - self, - data_batch: Dict, - guidance: float = 1.5, - is_negative_prompt: bool = False, - condition_latent: torch.Tensor = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - add_input_frames_guidance: bool = False, - seed_inference: int = 1, - ) -> Callable: - """ - Generates a callable function `x0_fn` based on the provided data batch and guidance factor. 
- Different from the base model, this function supports condition latent as input; it will add the condition information into the condition and uncondition objects. - - This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. - - Args: - - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` - - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. - - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. - - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" - - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - - add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames - - Returns: - - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and returns the x0 prediction - - The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. - """ - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - - condition.video_cond_bool = True - condition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, condition, num_condition_t - ) - if self.config.conditioner.video_cond_bool.add_pose_condition: - condition = self.add_condition_pose(data_batch, condition) - - uncondition.video_cond_bool = False if add_input_frames_guidance else True - uncondition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, uncondition, num_condition_t - ) - if self.config.conditioner.video_cond_bool.add_pose_condition: - uncondition = self.add_condition_pose(data_batch, uncondition) - - to_cp = self.net.is_context_parallel_enabled - # For inference, check if parallel_state is initialized - if parallel_state.is_initialized(): - condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) - uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) - else: - assert not to_cp, "parallel_state is not initialized, context parallel should be turned off."
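# Illustrative sketch (not from cosmos_transfer1; values are assumed): the nested x0_fn defined
# just below combines the conditional and unconditional x0 predictions with classifier-free
# guidance, i.e. x0 = cond_x0 + guidance * (cond_x0 - uncond_x0).
import torch

def cfg_combine(cond_x0: torch.Tensor, uncond_x0: torch.Tensor, guidance: float) -> torch.Tensor:
    # Push the prediction further along the direction the condition pulls it, scaled by guidance.
    return cond_x0 + guidance * (cond_x0 - uncond_x0)

# With guidance=1.5: 2.0 + 1.5 * (2.0 - 1.0) = 3.5
print(cfg_combine(torch.tensor([2.0]), torch.tensor([1.0]), guidance=1.5))  # tensor([3.5000])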
- - def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - cond_x0 = self.denoise( - noise_x, - sigma, - condition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - ).x0_pred_replaced - uncond_x0 = self.denoise( - noise_x, - sigma, - uncondition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - ).x0_pred_replaced - return cond_x0 + guidance * (cond_x0 - uncond_x0) - - return x0_fn - - def generate_samples_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - seed: int = 1, - state_shape: Tuple | None = None, - n_sample: int | None = None, - is_negative_prompt: bool = False, - num_steps: int = 35, - condition_latent: Union[torch.Tensor, None] = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - add_input_frames_guidance: bool = False, - return_noise: bool = False, - ) -> Tensor | Tuple[Tensor, Tensor]: - """ - Generate samples from the batch. Based on the given batch, it will automatically determine whether to generate image or video samples. - Different from the base model, this function supports condition latent as input; it will create a different x0_fn if condition latent is given. - If this feature is stabilized, we could consider moving this function to the base model. - - Args: - condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. - num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half - - add_input_frames_guidance (bool): add guidance to the input frames, used for cfg on input frames - """ - self._normalize_video_databatch_inplace(data_batch) - self._augment_image_dim_inplace(data_batch) - is_image_batch = self.is_image_batch(data_batch) - if is_image_batch: - log.debug("image batch, call base model generate_samples_from_batch") - return super().generate_samples_from_batch( - data_batch, - guidance=guidance, - seed=seed, - state_shape=state_shape, - n_sample=n_sample, - is_negative_prompt=is_negative_prompt, - num_steps=num_steps, - ) - if n_sample is None: - input_key = self.input_image_key if is_image_batch else self.input_data_key - n_sample = data_batch[input_key].shape[0] - if state_shape is None: - if is_image_batch: - state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W - else: - log.debug(f"Default Video state shape is used. 
{self.state_shape}") - state_shape = self.state_shape - - assert condition_latent is not None, "condition_latent should be provided" - x0_fn = self.get_x0_fn_from_batch_with_condition_latent( - data_batch, - guidance, - is_negative_prompt=is_negative_prompt, - condition_latent=condition_latent, - num_condition_t=num_condition_t, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - add_input_frames_guidance=add_input_frames_guidance, - seed_inference=seed, # Use for noise of augment sigma - ) - - generator = torch.Generator(device=self.tensor_kwargs["device"]) - generator.manual_seed(seed) - x_sigma_max = ( - torch.randn(n_sample, *state_shape, **self.tensor_kwargs, generator=generator) * self.sde.sigma_max - ) - - if self.net.is_context_parallel_enabled: - x_sigma_max = rearrange(x_sigma_max, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - x_sigma_max = rearrange(x_sigma_max, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - - samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max) - if self.net.is_context_parallel_enabled: - samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) - samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - - if return_noise: - if self.net.is_context_parallel_enabled: - x_sigma_max = rearrange(x_sigma_max, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - x_sigma_max = cat_outputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - x_sigma_max = rearrange(x_sigma_max, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - return samples, x_sigma_max / self.sde.sigma_max - - return samples - - def get_data_and_condition( - self, data_batch: dict[str, Tensor], num_condition_t: Union[int, None] = None - ) -> Tuple[Tensor, Tensor, ViewConditionedVideoExtendCondition]: - if self.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i: - raise NotImplementedError( - "sample_tokens_start_from_p_or_i is not implemented for multiview extension diffusion model" - ) - raw_state, latent_state, condition = super().get_data_and_condition(data_batch, num_condition_t=num_condition_t) - if condition.data_type == DataType.VIDEO and "view_indices" in data_batch: - comp_factor = self.vae.temporal_compression_factor - # n_frames = data_batch['num_frames'] - view_indices = rearrange(data_batch["view_indices"], "B (V T) -> B V T", V=self.n_views) - view_indices_B_V_0 = view_indices[:, :, :1] - view_indices_B_V_1T = view_indices[:, :, 1:-1:comp_factor] - view_indices_B_V_T = torch.cat([view_indices_B_V_0, view_indices_B_V_1T], dim=-1) - condition.view_indices_B_T = rearrange(view_indices_B_V_T, "B V T -> B (V T)", V=self.n_views) - condition.data_n_views = self.n_views - log.debug(f"condition.data_n_views {self.n_views}") - return raw_state, latent_state, condition - - -@diffusion_fsdp_class_decorator -class FSDPExtendDiffusionModel(MultiviewExtendDiffusionModel): - pass diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/extend_model_multiview_ctrl.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/extend_model_multiview_ctrl.py deleted file mode 100644 index a8ead411c19f6e41197bcafbaee9793f8170a09b..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/extend_model_multiview_ctrl.py +++ 
/dev/null @@ -1,501 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, Dict, Optional, Tuple, Union - -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import Tensor - -from cosmos_transfer1.diffusion.conditioner import CosmosCondition, DataType, ViewConditionedVideoConditionerWithCtrl -from cosmos_transfer1.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS -from cosmos_transfer1.diffusion.inference.inference_utils import ( - merge_patches_into_video, - non_strict_load_model, - split_video_into_patches, -) -from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp -from cosmos_transfer1.diffusion.training.models.extend_model_multiview import MultiviewExtendDiffusionModel -from cosmos_transfer1.diffusion.training.models.model import DiffusionModel as VideoDiffusionModel -from cosmos_transfer1.diffusion.training.models.model import _broadcast, broadcast_condition -from cosmos_transfer1.diffusion.training.models.model_image import diffusion_fsdp_class_decorator -from cosmos_transfer1.utils import log, misc -from cosmos_transfer1.utils.lazy_config import instantiate as lazy_instantiate - -IS_PREPROCESSED_KEY = "is_preprocessed" - - -class MultiVideoDiffusionModelWithCtrl(MultiviewExtendDiffusionModel): - def __init__(self, config, fsdp_checkpointer=None): - self.pixel_corruptor = None - if fsdp_checkpointer is not None: - return super().__init__(config, fsdp_checkpointer) - else: - return super().__init__(config) - - def build_model(self) -> torch.nn.ModuleDict: - log.info("Start creating base model") - base_model = super().build_model() - # initialize base model - config = self.config - self.load_base_model(base_model) - log.info("Done creating base model") - - log.info("Start creating ctrlnet model") - net = lazy_instantiate(self.config.net_ctrl) - conditioner = base_model.conditioner - logvar = base_model.logvar - # initialize controlnet encoder - model = torch.nn.ModuleDict({"net": net, "conditioner": conditioner, "logvar": logvar}) - - model.load_state_dict(base_model.state_dict(), strict=False) - - model.base_model = base_model - if not config.finetune_base_model: - model.base_model.requires_grad_(False) - log.critical("Only training ctrlnet model and keeping base model frozen") - else: - log.critical("Also training base model") - log.info("Done creating ctrlnet model") - - self.hint_key = self.config.hint_key["hint_key"] - return model - - @property - def base_net(self): - return self.model.base_model.net - - def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: - super().on_train_start(memory_format) - # self.base_model = self.base_model.to(memory_format=memory_format, **self.tensor_kwargs) - self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs) - if 
parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: - if parallel_state.sequence_parallel: - self.base_net.enable_sequence_parallel() - if hasattr(self.config, "use_torch_compile") and self.config.use_torch_compile: # compatible with old config - # not tested yet - if torch.__version__ < "2.3": - log.warning( - "torch.compile in Pytorch version older than 2.3 doesn't work well with activation checkpointing.\n" - "It's very likely there will be no significant speedup from torch.compile.\n" - "Please use at least 24.04 Pytorch container, or imaginaire4:v7 container." - ) - self.base_net = torch.compile(self.base_net, dynamic=False, disable=not self.config.use_torch_compile) - - def load_base_model(self, base_model) -> None: - config = self.config - if config.base_load_from is not None: - checkpoint_path = config.base_load_from["load_path"] - else: - checkpoint_path = "" - - if "*" in checkpoint_path: - # there might be better ways to decide if it's a converted tp checkpoint - mp_rank = parallel_state.get_model_parallel_group().rank() - checkpoint_path = checkpoint_path.replace("*", f"{mp_rank}") - - if checkpoint_path: - log.info(f"Loading base model checkpoint (local): {checkpoint_path}", False) - state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) - log.success(f"Complete loading base model checkpoint (local): {checkpoint_path}", False) - - if state_dict.get("ema") is not None: - # Copy the base model weights from ema model. - log.info("Copying ema to base model", False) - base_state_dict = {k.replace("-", "."): v for k, v in state_dict["ema"].items()} - elif "model" in state_dict: - # Copy the base model weights from reg model. - log.warning("Using non-EMA base model", False) - base_state_dict = state_dict["model"] - else: - log.info("Loading from an EMA only model", False) - base_state_dict = state_dict - try: - base_model.load_state_dict(base_state_dict, strict=False) - except Exception: - log.critical("load model in non-strict mode", False) - log.critical(non_strict_load_model(base_model, base_state_dict), rank0_only=False) - log.info("Done loading the base model checkpoint.", False) - - def get_data_and_condition( - self, data_batch: dict[str, Tensor], **kwargs - ) -> Tuple[Tensor, ViewConditionedVideoConditionerWithCtrl]: - # process the control input - hint_key = self.config.hint_key["hint_key"] - is_image_batch = self.is_image_batch(data_batch) - _data = {hint_key: data_batch[hint_key]} - if IS_PREPROCESSED_KEY in data_batch: - _data[IS_PREPROCESSED_KEY] = data_batch[IS_PREPROCESSED_KEY] - if not is_image_batch: - self._normalize_video_databatch_inplace(_data, input_key=hint_key) - # if it is an image batch, the control input is also image - if self.input_image_key in data_batch: - self._augment_image_dim_inplace(_data, input_key=hint_key) - data_batch[hint_key] = _data[hint_key] - - data_batch["hint_key"] = hint_key - - raw_state, latent_state, condition = super(MultiVideoDiffusionModelWithCtrl, self).get_data_and_condition( - data_batch, kwargs.get("num_condition_t", None) - ) - # if not torch.is_grad_enabled() and all(self.config.hint_mask): - use_multicontrol = ( - ("control_weight" in data_batch) - and not isinstance(data_batch["control_weight"], float) - and data_batch["control_weight"].shape[0] > 1 - ) - if use_multicontrol: # encode individual conditions separately - latent_hint = [] - num_conditions = data_batch[data_batch["hint_key"]].size(1) // 3 - for i in range(num_conditions): - cond_mask = 
[False] * num_conditions - cond_mask[i] = True - latent_hint += [self.encode_latent(data_batch, cond_mask=cond_mask)] - latent_hint = torch.cat(latent_hint) - else: - latent_hint = self.encode_latent(data_batch) - # copied from model.py - is_image_batch = self.is_image_batch(data_batch) - is_video_batch = not is_image_batch - # VAE has randomness. CP/TP group should have the same encoded output. - - # TODO: (qsh 2024-08-23) the following may not be necessary! - latent_hint = _broadcast(latent_hint, to_tp=True, to_cp=is_video_batch) - - # add extra conditions - data_batch["latent_hint"] = latent_hint - setattr(condition, hint_key, latent_hint) - setattr(condition, "base_model", self.model.base_model) - return raw_state, latent_state, condition - - def encode_latent(self, data_batch: dict, cond_mask: list = []) -> torch.Tensor: - x = data_batch[data_batch["hint_key"]] - if torch.is_grad_enabled() and self.pixel_corruptor is not None: - x = self.pixel_corruptor(x) - latent = [] - # control input goes through tokenizer, which always takes 3-input channels - num_conditions = x.size(1) // 3 # input conditions were concatenated along channel dimension - if num_conditions > 1 and self.config.hint_dropout_rate > 0: - if torch.is_grad_enabled(): # during training, randomly dropout some conditions - cond_mask = torch.rand(num_conditions) > self.config.hint_dropout_rate - if not cond_mask.any(): # make sure at least one condition is present - cond_mask[torch.randint(num_conditions, (1,)).item()] = True - elif not cond_mask: # during inference, use hint_mask to indicate which conditions are used - cond_mask = self.config.hint_mask - else: - cond_mask = [True] * num_conditions - for idx in range(0, x.size(1), 3): - x_rgb = x[:, idx : idx + 3] # B C (V T) H W - if self.config.hint_key["grayscale"]: - x_rgb = x_rgb.mean(dim=1, keepdim=True).expand_as(x_rgb) - # if idx == 0: - # x_max = x_rgb - # else: - # x_max = torch.maximum(x_rgb, x_max) - if not cond_mask[idx // 3]: # if the condition is not selected, replace with a black image - x_rgb = torch.zeros_like(x_rgb) - latent.append(self.encode(x_rgb)) - # latent.append(self.encode(x_max)) - latent = torch.cat(latent, dim=1) - return latent - - def compute_loss_with_epsilon_and_sigma( - self, - data_batch: dict[str, Tensor], - x0_from_data_batch: Tensor, - x0: Tensor, - condition: CosmosCondition, - epsilon: Tensor, - sigma: Tensor, - ): - if self.is_image_batch(data_batch): - # Turn off CP - self.net.disable_context_parallel() - self.base_net.disable_context_parallel() - else: - if parallel_state.is_initialized(): - if parallel_state.get_context_parallel_world_size() > 1: - # Turn on CP - cp_group = parallel_state.get_context_parallel_group() - self.net.enable_context_parallel(cp_group) - self.base_net.enable_context_parallel(cp_group) - log.debug("[CP] Split hint_input") - hint_key = self.config.hint_key["hint_key"] - x_hint_raw = getattr(condition, hint_key) - x_hint = split_inputs_cp(x=x_hint_raw, seq_dim=2, cp_group=self.net.cp_group) - setattr(condition, hint_key, x_hint) - return super(MultiVideoDiffusionModelWithCtrl, self).compute_loss_with_epsilon_and_sigma( - data_batch, x0_from_data_batch, x0, condition, epsilon, sigma - ) - - def get_x0_fn_from_batch_with_condition_latent( - self, - data_batch: Dict, - guidance: float = 1.5, - is_negative_prompt: bool = False, - condition_latent: torch.Tensor = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - add_input_frames_guidance: bool = False, 
- seed_inference: int = 1, - ) -> Callable: - """ - Generates a callable function `x0_fn` based on the provided data batch and guidance factor. - - This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. - - Args: - - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` - - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. - - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. - - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" - - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - - Returns: - - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin - - The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. - """ - # data_batch should be the one processed by self.get_data_and_condition - try: - if "control_weight" not in data_batch and self.control_weight is not None: - data_batch["control_weight"] = self.control_weight - log.info(f"Setting control weight to {self.control_weight}") - else: - log.info(f"Control weight is {data_batch['control_weight']}") - except Exception: - pass - - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - - if condition.data_type == DataType.VIDEO and "view_indices" in data_batch: - comp_factor = self.vae.temporal_compression_factor - # n_frames = data_batch['num_frames'] - view_indices = rearrange(data_batch["view_indices"], "B (V T) -> B V T", V=self.n_views) - view_indices_B_V_0 = view_indices[:, :, :1] - view_indices_B_V_1T = view_indices[:, :, 1:-1:comp_factor] - view_indices_B_V_T = torch.cat([view_indices_B_V_0, view_indices_B_V_1T], dim=-1) - condition.view_indices_B_T = rearrange(view_indices_B_V_T, "B V T -> B (V T)", V=self.n_views) - condition.data_n_views = self.n_views - uncondition.view_indices_B_T = condition.view_indices_B_T - uncondition.data_n_views = self.n_views - - if self.is_image_batch(data_batch): - condition.data_type = DataType.IMAGE - uncondition.data_type = DataType.IMAGE - else: - if condition_latent is None: - condition_latent = torch.zeros(data_batch["latent_hint"].shape, **self.tensor_kwargs) - num_condition_t = 0 - condition_video_augment_sigma_in_inference = 1000 - - condition.video_cond_bool = True - condition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, condition, num_condition_t - ) - if self.config.conditioner.video_cond_bool.add_pose_condition: - condition = self.add_condition_pose(data_batch, condition) - - uncondition.video_cond_bool = True # Not do cfg 
on condition frames - uncondition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, uncondition, num_condition_t - ) - if self.config.conditioner.video_cond_bool.add_pose_condition: - uncondition = self.add_condition_pose(data_batch, uncondition) - - # Add extra conditions for ctrlnet. - latent_hint = data_batch["latent_hint"] - hint_key = data_batch["hint_key"] - setattr(condition, hint_key, latent_hint) - if "use_none_hint" in data_batch and data_batch["use_none_hint"]: - setattr(uncondition, hint_key, None) - else: - setattr(uncondition, hint_key, latent_hint) - - to_cp = self.net.is_context_parallel_enabled - # For inference, check if parallel_state is initialized - if parallel_state.is_initialized() and not self.is_image_batch(data_batch): - condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) - uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) - cp_group = parallel_state.get_context_parallel_group() - latent_hint = getattr(condition, hint_key) - if latent_hint is not None: - latent_hint = split_inputs_cp(latent_hint, seq_dim=2, cp_group=cp_group) - setattr(condition, hint_key, latent_hint) - if getattr(uncondition, hint_key) is not None: - setattr(uncondition, hint_key, latent_hint) - - setattr(condition, "base_model", self.model.base_model) - setattr(uncondition, "base_model", self.model.base_model) - if hasattr(self, "hint_encoders"): - self.model.net.hint_encoders = self.hint_encoders - - def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - cond_x0 = self.denoise( - noise_x, - sigma, - condition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - ).x0_pred_replaced - uncond_x0 = self.denoise( - noise_x, - sigma, - uncondition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - ).x0_pred_replaced - return cond_x0 + guidance * (cond_x0 - uncond_x0) - - return x0_fn - - def generate_samples_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - seed: int = 1, - state_shape: Tuple | None = None, - n_sample: int | None = None, - is_negative_prompt: bool = False, - num_steps: int = 35, - condition_latent: Union[torch.Tensor, None] = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - solver_option: COMMON_SOLVER_OPTIONS = "2ab", - x_sigma_max: Optional[torch.Tensor] = None, - sigma_max: float | None = None, - add_input_frames_guidance: bool = False, - return_noise: bool = False, - ) -> Tensor: - """ - Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. - Different from the base model, this function support condition latent as input, it will create a differnt x0_fn if condition latent is given. - If this feature is stablized, we could consider to move this function to the base model. - - Args: - condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video. 
- num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half - """ - self._normalize_video_databatch_inplace(data_batch) - self._augment_image_dim_inplace(data_batch) - is_image_batch = self.is_image_batch(data_batch) - if is_image_batch: - log.debug("image batch, call base model generate_samples_from_batch") - return super(MultiVideoDiffusionModelWithCtrl, self).generate_samples_from_batch( - data_batch, - guidance=guidance, - seed=seed, - state_shape=state_shape, - n_sample=n_sample, - is_negative_prompt=is_negative_prompt, - num_steps=num_steps, - ) - if n_sample is None: - input_key = self.input_image_key if is_image_batch else self.input_data_key - n_sample = data_batch[input_key].shape[0] - if state_shape is None: - if is_image_batch: - state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W - else: - log.debug(f"Default Video state shape is used. {self.state_shape}") - state_shape = self.state_shape - - # assert condition_latent is not None, "condition_latent should be provided" - - # if self.net.is_context_parallel_enabled: - # data_batch["latent_hint"] = split_inputs_cp(x=data_batch["latent_hint"], seq_dim=2, cp_group=self.net.cp_group) - - x0_fn = self.get_x0_fn_from_batch_with_condition_latent( - data_batch, - guidance, - is_negative_prompt=is_negative_prompt, - condition_latent=condition_latent, - num_condition_t=num_condition_t, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - add_input_frames_guidance=add_input_frames_guidance, - seed_inference=seed, - ) - - if sigma_max is None: - sigma_max = self.sde.sigma_max - - if x_sigma_max is None: - x_sigma_max = ( - misc.arch_invariant_rand( - (n_sample,) + tuple(state_shape), - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - * sigma_max - ) - - if self.net.is_context_parallel_enabled: - x_sigma_max = _broadcast(x_sigma_max, to_tp=True, to_cp=True) - x_sigma_max = rearrange(x_sigma_max, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - x_sigma_max = rearrange(x_sigma_max, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - - samples = self.sampler( - x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max, solver_option=solver_option - ) - - if self.net.is_context_parallel_enabled: - samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) - samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - - return samples - - @torch.no_grad() - def validation_step( - self, data: dict[str, torch.Tensor], iteration: int - ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: - """ - save generated videos - """ - raw_data, x0, condition = self.get_data_and_condition(data) - guidance = data["guidance"] - sigma_max = data["sigma_max"] - is_negative_prompt = data["is_negative_prompt"] - data = misc.to(data, **self.tensor_kwargs) - x_sigma_max = None - if sigma_max is not None: - x_sigma_max = self.get_x_from_clean(x0, sigma_max) - sample = self.generate_samples_from_batch( - data, - guidance=guidance, - # make sure no mismatch and also works for cp - state_shape=x0.shape[1:], - n_sample=x0.shape[0], - x_sigma_max=x_sigma_max, - sigma_max=sigma_max, - is_negative_prompt=is_negative_prompt, - ) - sample = self.decode(sample) - gt = raw_data - hint = data[data["hint_key"]][:, :3] - result = torch.cat([hint, sample], dim=3) - gt = 
torch.cat([hint, gt], dim=3) - caption = data["ai_caption"] - return {"gt": gt, "result": result, "caption": caption}, torch.tensor([0]).to(**self.tensor_kwargs) - - -@diffusion_fsdp_class_decorator -class FSDPMultiVideoDiffusionModelWithCtrl(MultiVideoDiffusionModelWithCtrl): - pass diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/model.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/model.py deleted file mode 100644 index ff06cd4b11f31d87667391bc65e4955b56e67233..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/model.py +++ /dev/null @@ -1,660 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union - -import amp_C -import torch -from apex.multi_tensor_apply import multi_tensor_applier -from einops import rearrange -from megatron.core import parallel_state -from torch import Tensor -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from torch.distributed import broadcast_object_list, get_process_group_ranks -from torch.distributed.utils import _verify_param_shape_across_processes - -from cosmos_transfer1.diffusion.conditioner import BaseVideoCondition, DataType -from cosmos_transfer1.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS -from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp -from cosmos_transfer1.diffusion.training.models.model_image import CosmosCondition -from cosmos_transfer1.diffusion.training.models.model_image import DiffusionModel as ImageModel -from cosmos_transfer1.diffusion.training.models.model_image import diffusion_fsdp_class_decorator -from cosmos_transfer1.utils import distributed, log, misc - -l2_norm_impl = amp_C.multi_tensor_l2norm -multi_tensor_scale_impl = amp_C.multi_tensor_scale - -# key to check if the video data is normalized or image data is converted to video data -# to avoid apply normalization or augment image dimension multiple times -# It is due to we do not have normalization and augment image dimension in the dataloader and move it to the model -IS_PREPROCESSED_KEY = "is_preprocessed" - - -def robust_broadcast(tensor: torch.Tensor, src: int, pg, is_check_shape: bool = False) -> torch.Tensor: - """ - Perform a robust broadcast operation that works regardless of tensor shapes on different ranks. - - Args: - tensor (torch.Tensor): The tensor to broadcast (on src rank) or receive (on other ranks). - src (int): The source rank for the broadcast. Defaults to 0. - - Returns: - torch.Tensor: The broadcasted tensor on all ranks. 
- """ - # First, broadcast the shape of the tensor - if distributed.get_rank() == src: - shape = torch.tensor(tensor.shape).cuda() - else: - shape = torch.empty(tensor.dim(), dtype=torch.long).cuda() - if is_check_shape: - _verify_param_shape_across_processes(pg, [shape]) - torch.distributed.broadcast(shape, src, group=pg) - - # Resize the tensor on non-src ranks if necessary - if distributed.get_rank() != src: - tensor = tensor.new_empty(shape.tolist()).type_as(tensor) - - # Now broadcast the tensor data - torch.distributed.broadcast(tensor, src, group=pg) - - return tensor - - -def _broadcast(item: torch.Tensor | str | None, to_tp: bool = True, to_cp: bool = True) -> torch.Tensor | str | None: - """ - Broadcast the item from the minimum rank in the specified group(s). - Since global rank = tp_rank + cp_rank * tp_size + ... - First broadcast in the tp_group and then in the cp_group will - ensure that the item is broadcasted across ranks in cp_group and tp_group. - - Parameters: - - item: The item to broadcast (can be a torch.Tensor, str, or None). - - to_tp: Whether to broadcast to the tensor model parallel group. - - to_cp: Whether to broadcast to the context parallel group. - """ - if not parallel_state.is_initialized(): - return item - tp_group = parallel_state.get_tensor_model_parallel_group() - cp_group = parallel_state.get_context_parallel_group() - - to_tp = to_tp and parallel_state.get_tensor_model_parallel_world_size() > 1 - to_cp = to_cp and parallel_state.get_context_parallel_world_size() > 1 - - if to_tp: - min_tp_rank = min(get_process_group_ranks(tp_group)) - - if to_cp: - min_cp_rank = min(get_process_group_ranks(cp_group)) - - if isinstance(item, torch.Tensor): # assume the device is cuda - # log.info(f"{item.shape}", rank0_only=False) - if to_tp: - # torch.distributed.broadcast(item, min_tp_rank, group=tp_group) - item = robust_broadcast(item, min_tp_rank, tp_group) - if to_cp: - # torch.distributed.broadcast(item, min_cp_rank, group=cp_group) - item = robust_broadcast(item, min_cp_rank, cp_group) - elif item is not None: - broadcastable_list = [item] - if to_tp: - # log.info(f"{broadcastable_list}", rank0_only=False) - broadcast_object_list(broadcastable_list, min_tp_rank, group=tp_group) - if to_cp: - broadcast_object_list(broadcastable_list, min_cp_rank, group=cp_group) - - item = broadcastable_list[0] - return item - - -def broadcast_condition(condition: BaseVideoCondition, to_tp: bool = True, to_cp: bool = True) -> BaseVideoCondition: - condition_kwargs = {} - for k, v in condition.to_dict().items(): - if isinstance(v, torch.Tensor): - assert not v.requires_grad, f"{k} requires gradient. 
the current impl does not support it" - condition_kwargs[k] = _broadcast(v, to_tp=to_tp, to_cp=to_cp) - condition = type(condition)(**condition_kwargs) - return condition - - -class DiffusionModel(ImageModel): - def __init__(self, config): - super().__init__(config) - # Initialize trained_data_record, keys: image, video, iteration - self.trained_data_record = { - "image": 0, - "video": 0, - "iteration": 0, - } - if parallel_state.is_initialized(): - self.data_parallel_size = parallel_state.get_data_parallel_world_size() - else: - self.data_parallel_size = 1 - - if self.config.adjust_video_noise: - self.video_noise_multiplier = math.sqrt(self.state_shape[1]) - else: - self.video_noise_multiplier = 1.0 - - def setup_data_key(self) -> None: - self.input_data_key = self.config.input_data_key # by default it is video key for Video diffusion model - self.input_image_key = self.config.input_image_key - - def is_image_batch(self, data_batch: dict[str, Tensor]) -> bool: - """We handle two types of data_batch. One comes from a joint_dataloader where "dataset_name" can be used to differentiate image_batch and video_batch. - Another comes from a dataloader which we by default assume is video data for video model training. - """ - is_image = self.input_image_key in data_batch - is_video = self.input_data_key in data_batch - assert ( - is_image != is_video - ), "Only one of the input_image_key or input_data_key should be present in the data_batch." - return is_image - - def draw_training_sigma_and_epsilon(self, size: int, condition: BaseVideoCondition) -> Tensor: - sigma_B, epsilon = super().draw_training_sigma_and_epsilon(size, condition) - is_video_batch = condition.data_type == DataType.VIDEO - multiplier = self.video_noise_multiplier if is_video_batch else 1 - sigma_B = _broadcast(sigma_B * multiplier, to_tp=True, to_cp=is_video_batch) - epsilon = _broadcast(epsilon, to_tp=True, to_cp=is_video_batch) - return sigma_B, epsilon - - @torch.no_grad() - def validation_step( - self, data: dict[str, torch.Tensor], iteration: int - ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: - """ - save generated videos - """ - raw_data, x0, condition = self.get_data_and_condition(data) - guidance = data["guidance"] - data = misc.to(data, **self.tensor_kwargs) - sample = self.generate_samples_from_batch( - data, - guidance=guidance, - # make sure no mismatch and also works for cp - state_shape=x0.shape[1:], - n_sample=x0.shape[0], - ) - sample = self.decode(sample) - gt = raw_data - caption = data["ai_caption"] - return {"gt": gt, "result": sample, "caption": caption}, torch.tensor([0]).to(**self.tensor_kwargs) - - def training_step(self, data_batch: Dict[str, Tensor], iteration: int) -> Tuple[Dict[str, Tensor] | Tensor]: - input_key = self.input_data_key # by default it is video key - if self.is_image_batch(data_batch): - input_key = self.input_image_key - batch_size = data_batch[input_key].shape[0] - self.trained_data_record["image" if self.is_image_batch(data_batch) else "video"] += ( - batch_size * self.data_parallel_size - ) - self.trained_data_record["iteration"] += 1 - return super().training_step(data_batch, iteration) - - def state_dict(self) -> Dict[str, Any]: - state_dict = super().state_dict() - state_dict["trained_data_record"] = self.trained_data_record - return state_dict - - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): - if "trained_data_record" in state_dict and hasattr(self, "trained_data_record"): - trained_data_record = 
state_dict.pop("trained_data_record") - if trained_data_record: - assert set(trained_data_record.keys()) == set(self.trained_data_record.keys()) - for k, v in trained_data_record.items(): - self.trained_data_record[k] = v - else: - log.warning("trained_data_record not found in the state_dict.") - return super().load_state_dict(state_dict, strict, assign) - - def _normalize_video_databatch_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: - """ - Normalizes video data in-place on a CUDA device to reduce data loading overhead. - - This function modifies the video data tensor within the provided data_batch dictionary - in-place, scaling the uint8 data from the range [0, 255] to the normalized range [-1, 1]. - - Warning: - A warning is issued if the data has not been previously normalized. - - Args: - data_batch (dict[str, Tensor]): A dictionary containing the video data under a specific key. - This tensor is expected to be on a CUDA device and have dtype of torch.uint8. - - Side Effects: - Modifies the 'input_data_key' tensor within the 'data_batch' dictionary in-place. - - Note: - This operation is performed directly on the CUDA device to avoid the overhead associated - with moving data to/from the GPU. Ensure that the tensor is already on the appropriate device - and has the correct dtype (torch.uint8) to avoid unexpected behaviors. - """ - input_key = self.input_data_key if input_key is None else input_key - # only handle video batch - if input_key in data_batch: - # Check if the data has already been normalized and avoid re-normalizing - if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: - assert torch.is_floating_point(data_batch[input_key]), "Video data is not in float format." - assert torch.all( - (data_batch[input_key] >= -1.0001) & (data_batch[input_key] <= 1.0001) - ), f"Video data is not in the range [-1, 1]. get data range [{data_batch[input_key].min()}, {data_batch[input_key].max()}]" - else: - assert data_batch[input_key].dtype == torch.uint8, "Video data is not in uint8 format." - data_batch[input_key] = data_batch[input_key].to(**self.tensor_kwargs) / 127.5 - 1.0 - data_batch[IS_PREPROCESSED_KEY] = True - - def _augment_image_dim_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: - input_key = self.input_image_key if input_key is None else input_key - if input_key in data_batch: - # Check if the data has already been augmented and avoid re-augmenting - if IS_PREPROCESSED_KEY in data_batch and data_batch[IS_PREPROCESSED_KEY] is True: - assert ( - data_batch[input_key].shape[2] == 1 - ), f"Image data is claimed be augmented while its shape is {data_batch[input_key].shape}" - return - else: - data_batch[input_key] = rearrange(data_batch[input_key], "b c h w -> b c 1 h w").contiguous() - data_batch[IS_PREPROCESSED_KEY] = True - - def get_data_and_condition(self, data_batch: dict[str, Tensor]) -> Tuple[Tensor, BaseVideoCondition]: - self._normalize_video_databatch_inplace(data_batch) - self._augment_image_dim_inplace(data_batch) - input_key = self.input_data_key # by default it is video key - is_image_batch = self.is_image_batch(data_batch) - is_video_batch = not is_image_batch - - # Broadcast data and condition across TP and CP groups. - # sort keys to make sure the order is same, IMPORTANT! otherwise, nccl will hang! 
- local_keys = sorted(list(data_batch.keys())) - # log.critical(f"all keys {local_keys}", rank0_only=False) - for key in local_keys: - data_batch[key] = _broadcast(data_batch[key], to_tp=True, to_cp=is_video_batch) - - if is_image_batch: - input_key = self.input_image_key - - # Latent state - raw_state = data_batch[input_key] - latent_state = self.encode(raw_state).contiguous() - - # Condition - condition = self.conditioner(data_batch) - if is_image_batch: - condition.data_type = DataType.IMAGE - else: - condition.data_type = DataType.VIDEO - - # VAE has randomness. CP/TP group should have the same encoded output. - - latent_state = _broadcast(latent_state, to_tp=True, to_cp=is_video_batch) - condition = broadcast_condition(condition, to_tp=True, to_cp=is_video_batch) - - return raw_state, latent_state, condition - - def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: - super().on_train_start(memory_format) - if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: - sequence_parallel = getattr(parallel_state, "sequence_parallel", False) - if sequence_parallel: - self.net.enable_sequence_parallel() - - def compute_loss_with_epsilon_and_sigma( - self, - data_batch: dict[str, torch.Tensor], - x0_from_data_batch: torch.Tensor, - x0: torch.Tensor, - condition: CosmosCondition, - epsilon: torch.Tensor, - sigma: torch.Tensor, - ): - if self.is_image_batch(data_batch): - # Turn off CP - self.net.disable_context_parallel() - else: - if parallel_state.is_initialized(): - if parallel_state.get_context_parallel_world_size() > 1: - # Turn on CP - cp_group = parallel_state.get_context_parallel_group() - self.net.enable_context_parallel(cp_group) - log.debug("[CP] Split x0 and epsilon") - x0 = split_inputs_cp(x=x0, seq_dim=2, cp_group=self.net.cp_group) - epsilon = split_inputs_cp(x=epsilon, seq_dim=2, cp_group=self.net.cp_group) - - output_batch, kendall_loss, pred_mse, edm_loss = super().compute_loss_with_epsilon_and_sigma( - data_batch, x0_from_data_batch, x0, condition, epsilon, sigma - ) - if not self.is_image_batch(data_batch): - if self.loss_reduce == "sum" and parallel_state.get_context_parallel_world_size() > 1: - kendall_loss *= parallel_state.get_context_parallel_world_size() - - return output_batch, kendall_loss, pred_mse, edm_loss - - def get_x0_fn_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - is_negative_prompt: bool = False, - ) -> Callable: - """ - Generates a callable function `x0_fn` based on the provided data batch and guidance factor. - - This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. - - Args: - - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` - - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. 
- - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - - Returns: - - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin - - The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. - """ - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - - to_cp = self.net.is_context_parallel_enabled - # For inference, check if parallel_state is initialized - if parallel_state.is_initialized(): - condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) - uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) - else: - assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." - - def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - cond_x0 = self.denoise(noise_x, sigma, condition).x0 - uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 - raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) - if "guided_image" in data_batch: - # replacement trick that enables inpainting with base model - assert "guided_mask" in data_batch, "guided_mask should be in data_batch if guided_image is present" - guide_image = data_batch["guided_image"] - guide_mask = data_batch["guided_mask"] - raw_x0 = guide_mask * guide_image + (1 - guide_mask) * raw_x0 - return raw_x0 - - return x0_fn - - def get_x_from_clean( - self, - in_clean_img: torch.Tensor, - sigma_max: float | None, - seed: int = 1, - ) -> Tensor: - """ - in_clean_img (torch.Tensor): input clean image for image-to-image/video-to-video by adding noise then denoising - sigma_max (float): maximum sigma applied to in_clean_image for image-to-image/video-to-video - """ - if in_clean_img is None: - return None - generator = torch.Generator(device=self.tensor_kwargs["device"]) - generator.manual_seed(seed) - noise = torch.randn(*in_clean_img.shape, **self.tensor_kwargs, generator=generator) - if sigma_max is None: - sigma_max = self.sde.sigma_max - x_sigma_max = in_clean_img + noise * sigma_max - return x_sigma_max - - def generate_samples_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - seed: int = 1, - state_shape: Tuple | None = None, - n_sample: int | None = None, - is_negative_prompt: bool = False, - num_steps: int = 35, - solver_option: COMMON_SOLVER_OPTIONS = "2ab", - x_sigma_max: Optional[torch.Tensor] = None, - sigma_max: float | None = None, - return_noise: bool = False, - ) -> Tensor | Tuple[Tensor, Tensor]: - """ - Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. - Args: - data_batch (dict): raw data batch draw from the training data loader. - iteration (int): Current iteration number. 
- guidance (float): guidance weights - seed (int): random seed - state_shape (tuple): shape of the state, default to self.state_shape if not provided - n_sample (int): number of samples to generate - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - num_steps (int): number of steps for the diffusion process - solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) - return_noise (bool): return the initial noise or not, used for ODE pairs generation - """ - self._normalize_video_databatch_inplace(data_batch) - self._augment_image_dim_inplace(data_batch) - is_image_batch = self.is_image_batch(data_batch) - if n_sample is None: - input_key = self.input_image_key if is_image_batch else self.input_data_key - n_sample = data_batch[input_key].shape[0] - if state_shape is None: - if is_image_batch: - state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W - - x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt) - - x_sigma_max = ( - misc.arch_invariant_rand( - (n_sample,) + tuple(state_shape), - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - * self.sde.sigma_max - ) - - if self.net.is_context_parallel_enabled: - x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - - samples = self.sampler( - x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max, solver_option=solver_option - ) - if self.net.is_context_parallel_enabled: - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) - - if return_noise: - if self.net.is_context_parallel_enabled: - x_sigma_max = cat_outputs_cp(x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - return samples, x_sigma_max / self.sde.sigma_max - - return samples - - def on_after_backward(self, iteration: int = 0): - finalize_model_grads([self]) - - def get_grad_norm( - self, - norm_type: Union[int, float] = 2, - filter_fn: Callable[[str, torch.nn.Parameter], bool] | None = None, - ) -> float: - """Calculate the norm of gradients, handling model parallel parameters. - - This function is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ - with added functionality to handle model parallel parameters. - - Args: - norm_type (float or int): Type of norm to use. Can be 2 for L2 norm. - 'inf' for infinity norm is not supported. - filter_fn (callable, optional): Function to filter parameters for norm calculation. - Takes parameter name and parameter as input, returns True if this parameter is sharded else False. - - Returns: - float: Total norm of the parameters (viewed as a single vector). - - Note: - - Uses NVIDIA's multi-tensor applier for efficient norm calculation. - - Handles both model parallel and non-model parallel parameters separately. - - Currently only supports L2 norm (norm_type = 2). 
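# --- Editor's illustrative sketch (not part of the deleted file above). ---
# The model-parallel grad-norm recipe described in the docstring above, written with
# plain torch.linalg.vector_norm instead of NVIDIA's multi-tensor applier: square the
# norms of sharded (model-parallel) gradients, all-reduce the squares across the
# model-parallel group, then combine with the replicated gradients and take the sqrt.
# `is_sharded` and `model_parallel_group` are hypothetical stand-ins.
import torch
import torch.distributed as dist

def grad_l2_norm(named_params, is_sharded, model_parallel_group=None) -> float:
    sharded_sq = torch.zeros(())
    replicated_sq = torch.zeros(())
    for name, p in named_params:
        if p.grad is None:
            continue
        sq = torch.linalg.vector_norm(p.grad.detach().float()).cpu() ** 2
        if is_sharded(name, p):
            sharded_sq += sq
        else:
            replicated_sq += sq
    if model_parallel_group is not None:
        # Each rank holds a different shard; summing the squares reconstructs
        # the norm of the full, unsharded gradient.
        sharded_sq = sharded_sq.cuda()
        dist.all_reduce(sharded_sq, op=dist.ReduceOp.SUM, group=model_parallel_group)
        sharded_sq = sharded_sq.cpu()
    return float((sharded_sq + replicated_sq) ** 0.5)

# Toy usage with no sharded parameters (single process, no process group needed).
m = torch.nn.Linear(4, 4)
m(torch.randn(2, 4)).sum().backward()
print(grad_l2_norm(m.named_parameters(), is_sharded=lambda name, p: False))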
- """ - # Get model parallel group if parallel state is initialized - if parallel_state.is_initialized(): - model_parallel_group = parallel_state.get_model_parallel_group() - else: - model_parallel_group = None - - # Default filter function to identify tensor parallel parameters - if filter_fn is None: - - def is_tp(name, param): - return ( - any(key in name for key in ["to_q.0", "to_k.0", "to_v.0", "to_out.0", "layer1", "layer2"]) - and "_extra_state" not in name - ) - - filter_fn = is_tp - - # Separate gradients into model parallel and non-model parallel - without_mp_grads_for_norm = [] - with_mp_grads_for_norm = [] - for name, param in self.named_parameters(): - if param.grad is not None: - if filter_fn(name, param): - with_mp_grads_for_norm.append(param.grad.detach()) - else: - without_mp_grads_for_norm.append(param.grad.detach()) - - # Only L2 norm is currently supported - if norm_type != 2.0: - raise NotImplementedError(f"Norm type {norm_type} is not supported. Only L2 norm (2.0) is implemented.") - - # Calculate L2 norm using NVIDIA's multi-tensor applier - dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") - - # Calculate norm for non-model parallel gradients - without_mp_grad_norm = torch.tensor([0], dtype=torch.float, device="cuda") - if without_mp_grads_for_norm: - without_mp_grad_norm, _ = multi_tensor_applier( - l2_norm_impl, - dummy_overflow_buf, - [without_mp_grads_for_norm], - False, # no per-parameter norm - ) - - # Calculate norm for model parallel gradients - with_mp_grad_norm = torch.tensor([0], dtype=torch.float, device="cuda") - if with_mp_grads_for_norm: - with_mp_grad_norm, _ = multi_tensor_applier( - l2_norm_impl, - dummy_overflow_buf, - [with_mp_grads_for_norm], - False, # no per-parameter norm - ) - - # Square the norms as we'll be summing across model parallel GPUs - total_without_mp_norm = without_mp_grad_norm**2 - total_with_mp_norm = with_mp_grad_norm**2 - - # Sum across all model-parallel GPUs - torch.distributed.all_reduce(total_with_mp_norm, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group) - - # Combine norms from model parallel and non-model parallel gradients - total_norm = (total_with_mp_norm.item() + total_without_mp_norm.item()) ** 0.5 - - return total_norm - - def clip_grad_norm_(self, max_norm: float): - """ - This function performs gradient clipping to prevent exploding gradients. - It calculates the total norm of the gradients, and if it exceeds the - specified max_norm, scales the gradients down proportionally. - - Args: - max_norm (float): The maximum allowed norm for the gradients. - - Returns: - torch.Tensor: The total norm of the gradients before clipping. - - Note: - This implementation uses NVIDIA's multi-tensor applier for efficiency. 
- """ - # Collect gradients from all parameters that require gradients - grads = [] - for param in self.parameters(): - if param.grad is not None: - grads.append(param.grad.detach()) - - # Calculate the total norm of the gradients - total_norm = self.get_grad_norm() - - # Compute the clipping coefficient - clip_coeff = max_norm / (total_norm + 1.0e-6) - - # Apply gradient clipping if the total norm exceeds max_norm - if clip_coeff < 1.0: - dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") - # Apply the scaling to the gradients using multi_tensor_applier for efficiency - multi_tensor_applier(multi_tensor_scale_impl, dummy_overflow_buf, [grads, grads], clip_coeff) - - return torch.tensor([total_norm]) - - -def _allreduce_layernorm_grads(model: List[torch.nn.Module]): - """ - All-reduce the following layernorm grads: - - When tensor parallel is enabled, all-reduce grads of QK-layernorm - - When sequence parallel, all-reduce grads of AdaLN, t_embedder, additional_timestamp_embedder, - and affline_norm. - """ - sequence_parallel = getattr(parallel_state, "sequence_parallel", False) - - if parallel_state.get_tensor_model_parallel_world_size() > 1: - grads = [] - for model_chunk in model: - for name, param in model_chunk.named_parameters(): - if not param.requires_grad: - continue - - if "to_q.1" in name or "to_k.1" in name: # TP # Q-layernorm # K-layernorm - grad = param.grad - if grad is not None: - grads.append(grad.data) - - if sequence_parallel: # TP + SP - if ( - "t_embedder" in name - or "adaLN_modulation" in name - or "additional_timestamp_embedder" in name - or "affline_norm" in name - or "input_hint_block" in name - or "zero_blocks" in name - ): - grad = param.grad - if grad is not None: - grads.append(grad.data) - - if grads: - coalesced = _flatten_dense_tensors(grads) - torch.distributed.all_reduce(coalesced, group=parallel_state.get_tensor_model_parallel_group()) - for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): - buf.copy_(synced) - - -def finalize_model_grads(model: List[torch.nn.Module]): - """ - All-reduce layernorm grads for tensor/sequence parallelism. - Reference implementation: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/finalize_model_grads.py#L99 - """ - - _allreduce_layernorm_grads(model) - - -@diffusion_fsdp_class_decorator -class FSDPDiffusionModel(DiffusionModel): - pass diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/model_ctrl.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/model_ctrl.py deleted file mode 100644 index 08c85549b042c082b7f9873e3d8c88d0757c20ff..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/model_ctrl.py +++ /dev/null @@ -1,754 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Callable, Dict, Optional, Tuple, Type, TypeVar, Union - -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import Tensor - -from cosmos_transfer1.diffusion.conditioner import CosmosCondition, DataType, VideoConditionerWithCtrl -from cosmos_transfer1.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS -from cosmos_transfer1.diffusion.inference.inference_utils import ( - merge_patches_into_video, - non_strict_load_model, - split_video_into_patches, -) -from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp -from cosmos_transfer1.diffusion.training.models.extend_model import ExtendDiffusionModel as ExtendVideoDiffusionModel -from cosmos_transfer1.diffusion.training.models.model import DiffusionModel as VideoDiffusionModel -from cosmos_transfer1.diffusion.training.models.model import _broadcast, broadcast_condition -from cosmos_transfer1.diffusion.training.models.model_image import diffusion_fsdp_class_decorator -from cosmos_transfer1.utils import log, misc -from cosmos_transfer1.utils.lazy_config import instantiate - -T = TypeVar("T") -IS_PREPROCESSED_KEY = "is_preprocessed" - - -def ctrlnet_decorator(base_class: Type[T]) -> Type[T]: - class CtrlNetModel(base_class): - def __init__(self, config, fsdp_checkpointer=None): - if fsdp_checkpointer is not None: - return super().__init__(config, fsdp_checkpointer) - else: - return super().__init__(config) - - def build_model(self) -> torch.nn.ModuleDict: - log.info("Start creating base model") - base_model = super().build_model() - # initialize base model - config = self.config - self.load_base_model(base_model) - log.info("Done creating base model") - - log.info("Start creating ctrlnet model") - net = instantiate(self.config.net_ctrl) - conditioner = base_model.conditioner - logvar = base_model.logvar - # initialize controlnet encoder - model = torch.nn.ModuleDict({"net": net, "conditioner": conditioner, "logvar": logvar}) - model.load_state_dict(base_model.state_dict(), strict=False) - - model.base_model = base_model - if not config.finetune_base_model: - model.base_model.requires_grad_(False) - log.critical("Only training ctrlnet model and keeping base model frozen") - else: - log.critical("Also training base model") - log.info("Done creating ctrlnet model") - - self.hint_key = self.config.hint_key["hint_key"] - return model - - @property - def base_net(self): - return self.model.base_model.net - - def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: - super().on_train_start(memory_format) - # self.base_model = self.base_model.to(memory_format=memory_format, **self.tensor_kwargs) - self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs) - if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: - if parallel_state.sequence_parallel: - self.base_net.enable_sequence_parallel() - if ( - hasattr(self.config, "use_torch_compile") and self.config.use_torch_compile - ): # compatible with old config - # not tested yet - if torch.__version__ < "2.3": - log.warning( - "torch.compile in Pytorch version older than 2.3 doesn't work well with activation checkpointing.\n" - "It's very likely there will be no significant speedup from torch.compile.\n" - "Please use at least 24.04 Pytorch container, or imaginaire4:v7 container." 
- ) - self.base_net = torch.compile(self.base_net, dynamic=False, disable=not self.config.use_torch_compile) - - def load_base_model(self, base_model) -> None: - config = self.config - if config.base_load_from is not None: - checkpoint_path = config.base_load_from["load_path"] - else: - checkpoint_path = "" - - if "*" in checkpoint_path: - # there might be better ways to decide if it's a converted tp checkpoint - mp_rank = parallel_state.get_model_parallel_group().rank() - checkpoint_path = checkpoint_path.replace("*", f"{mp_rank}") - - if checkpoint_path: - log.info(f"Loading base model checkpoint (local): {checkpoint_path}") - state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) - log.success(f"Complete loading base model checkpoint (local): {checkpoint_path}") - - if state_dict.get("ema") is not None: - # Copy the base model weights from ema model. - log.info("Copying ema to base model") - base_state_dict = {k.replace("-", "."): v for k, v in state_dict["ema"].items()} - elif "model" in state_dict: - # Copy the base model weights from reg model. - log.warning("Using non-EMA base model") - base_state_dict = state_dict["model"] - else: - log.info("Loading from an EMA only model") - base_state_dict = state_dict - try: - base_model.load_state_dict(base_state_dict, strict=False) - except Exception: - log.critical("load model in non-strict mode") - log.critical(non_strict_load_model(base_model, base_state_dict), rank0_only=False) - log.info("Done loading the base model checkpoint.") - - return CtrlNetModel - - -def video_ctrlnet_decorator(base_class: Type[T]) -> Type[T]: - class VideoDiffusionModelWithCtrlWrapper(base_class): - def __init__(self, config): - super().__init__(config) - if hasattr(config, "pixel_corruptor") and config.pixel_corruptor is not None: - self.pixel_corruptor = instantiate(config.pixel_corruptor) - self.pixel_corruptor.to(**self.tensor_kwargs) - else: - self.pixel_corruptor = None - - def get_data_and_condition( - self, data_batch: dict[str, Tensor], **kwargs - ) -> Tuple[Tensor, VideoConditionerWithCtrl]: - # process the control input - hint_key = self.config.hint_key["hint_key"] - is_image_batch = self.is_image_batch(data_batch) - _data = {hint_key: data_batch[hint_key]} - if IS_PREPROCESSED_KEY in data_batch: - _data[IS_PREPROCESSED_KEY] = data_batch[IS_PREPROCESSED_KEY] - if not is_image_batch: - self._normalize_video_databatch_inplace(_data, input_key=hint_key) - # if it is an image batch, the control input is also image - if self.input_image_key in data_batch: - self._augment_image_dim_inplace(_data, input_key=hint_key) - data_batch[hint_key] = _data[hint_key] - # else: - # raise NotImplementedError(f"{self.config.hint_key} is not implemented.") - data_batch["hint_key"] = hint_key - raw_state, latent_state, condition = super().get_data_and_condition(data_batch, **kwargs) - # if not torch.is_grad_enabled() and all(self.config.hint_mask): - use_multicontrol = ( - ("control_weight" in data_batch) - and not isinstance(data_batch["control_weight"], float) - and data_batch["control_weight"].shape[0] > 1 - ) - if use_multicontrol: # encode individual conditions separately - latent_hint = [] - num_conditions = data_batch[data_batch["hint_key"]].size(1) // 3 - for i in range(num_conditions): - cond_mask = [False] * num_conditions - cond_mask[i] = True - latent_hint += [self.encode_latent(data_batch, cond_mask=cond_mask)] - latent_hint = torch.cat(latent_hint) - else: - latent_hint = self.encode_latent(data_batch) - # copied from model.py - 
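# --- Editor's illustrative sketch (not part of the deleted file above). ---
# The checkpoint-loading fallback used by load_base_model above: EMA weights are
# stored with "-" in place of "." in their keys, so they are remapped before loading,
# and loading falls back to a non-strict merge when keys are missing or unexpected.
# `load_base_weights` is a hypothetical helper, not part of the original API.
import torch

def load_base_weights(model: torch.nn.Module, ckpt: dict) -> None:
    if ckpt.get("ema") is not None:
        state = {k.replace("-", "."): v for k, v in ckpt["ema"].items()}  # EMA keys use "-"
    elif "model" in ckpt:
        state = ckpt["model"]  # regular (non-EMA) weights
    else:
        state = ckpt           # EMA-only checkpoint
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        print(f"non-strict load: {len(missing)} missing, {len(unexpected)} unexpected keys")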
is_image_batch = self.is_image_batch(data_batch) - is_video_batch = not is_image_batch - # VAE has randomness. CP/TP group should have the same encoded output. - - latent_hint = _broadcast(latent_hint, to_tp=True, to_cp=is_video_batch) - - # add extra conditions - data_batch["latent_hint"] = latent_hint - setattr(condition, hint_key, latent_hint) - setattr(condition, "base_model", self.model.base_model) - return raw_state, latent_state, condition - - def encode_latent(self, data_batch: dict, cond_mask: list = []) -> torch.Tensor: - x = data_batch[data_batch["hint_key"]] - if torch.is_grad_enabled() and self.pixel_corruptor is not None: - x = self.pixel_corruptor(x) - latent = [] - # control input goes through tokenizer, which always takes 3-input channels - num_conditions = x.size(1) // 3 # input conditions were concatenated along channel dimension - if num_conditions > 1 and self.config.hint_dropout_rate > 0: - if torch.is_grad_enabled(): # during training, randomly dropout some conditions - cond_mask = torch.rand(num_conditions) > self.config.hint_dropout_rate - if not cond_mask.any(): # make sure at least one condition is present - cond_mask[torch.randint(num_conditions, (1,)).item()] = True - elif not cond_mask: # during inference, use hint_mask to indicate which conditions are used - cond_mask = self.config.hint_mask - else: - cond_mask = [True] * num_conditions - for idx in range(0, x.size(1), 3): - x_rgb = x[:, idx : idx + 3] - if self.config.hint_key["grayscale"]: - x_rgb = x_rgb.mean(dim=1, keepdim=True).expand_as(x_rgb) - # if idx == 0: - # x_max = x_rgb - # else: - # x_max = torch.maximum(x_rgb, x_max) - if not cond_mask[idx // 3]: # if the condition is not selected, replace with a black image - x_rgb = torch.zeros_like(x_rgb) - latent.append(self.encode(x_rgb)) - # latent.append(self.encode(x_max)) - latent = torch.cat(latent, dim=1) - return latent - - def compute_loss_with_epsilon_and_sigma( - self, - data_batch: dict[str, Tensor], - x0_from_data_batch: Tensor, - x0: Tensor, - condition: CosmosCondition, - epsilon: Tensor, - sigma: Tensor, - ): - if self.is_image_batch(data_batch): - # Turn off CP - self.net.disable_context_parallel() - self.base_net.disable_context_parallel() - else: - if parallel_state.get_context_parallel_world_size() > 1: - # Turn on CP - cp_group = parallel_state.get_context_parallel_group() - self.net.enable_context_parallel(cp_group) - self.base_net.enable_context_parallel(cp_group) - log.debug("[CP] Split hint_input") - hint_key = self.config.hint_key["hint_key"] - x_hint_raw = getattr(condition, hint_key) - x_hint = split_inputs_cp(x=x_hint_raw, seq_dim=2, cp_group=self.net.cp_group) - setattr(condition, hint_key, x_hint) - return super().compute_loss_with_epsilon_and_sigma( - data_batch, x0_from_data_batch, x0, condition, epsilon, sigma - ) - - def get_x0_fn_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - is_negative_prompt: bool = False, - condition_latent: torch.Tensor = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - seed_inference: int = 1, - ) -> Callable: - """ - Generates a callable function `x0_fn` based on the provided data batch and guidance factor. - - This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. 
It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. - - Args: - - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` - - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. - - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - condition_latent (torch.Tensor): latent tensor in shape B,C,T,H,W as condition to generate video. - - num_condition_t (int): number of condition latent T, used in inference to decide the condition region and config.conditioner.video_cond_bool.condition_location == "first_n" - - condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference - - Returns: - - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin - - The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. - """ - # data_batch should be the one processed by self.get_data_and_condition - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - if hasattr(self, "is_extend_model") and self.is_extend_model: - # Add conditions for long video generation. - if self.is_image_batch(data_batch): - condition.data_type = DataType.IMAGE - uncondition.data_type = DataType.IMAGE - else: - if condition_latent is None: - condition_latent = torch.zeros(data_batch["latent_hint"].shape, **self.tensor_kwargs) - num_condition_t = 0 - condition_video_augment_sigma_in_inference = 1000 - - condition.video_cond_bool = True - condition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, condition, num_condition_t - ) - if self.config.conditioner.video_cond_bool.add_pose_condition: - condition = self.add_condition_pose(data_batch, condition) - - uncondition.video_cond_bool = True # Not do cfg on condition frames - uncondition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent, uncondition, num_condition_t - ) - if self.config.conditioner.video_cond_bool.add_pose_condition: - uncondition = self.add_condition_pose(data_batch, uncondition) - - # Add extra conditions for ctrlnet. 
- latent_hint = data_batch["latent_hint"]
- hint_key = data_batch["hint_key"]
- setattr(condition, hint_key, latent_hint)
- if "use_none_hint" in data_batch and data_batch["use_none_hint"]:
- setattr(uncondition, hint_key, None)
- else:
- setattr(uncondition, hint_key, latent_hint)
-
- to_cp = self.net.is_context_parallel_enabled
- # For inference, check if parallel_state is initialized
- if parallel_state.is_initialized() and not self.is_image_batch(data_batch):
- condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp)
- uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp)
-
- cp_group = parallel_state.get_context_parallel_group()
- latent_hint = getattr(condition, hint_key)
- latent_hint = split_inputs_cp(latent_hint, seq_dim=2, cp_group=cp_group)
- setattr(condition, hint_key, latent_hint)
- if getattr(uncondition, hint_key) is not None:
- setattr(uncondition, hint_key, latent_hint)
- # else:
- # assert not to_cp, "parallel_state is not initialized, context parallel should be turned off."
-
- setattr(condition, "base_model", self.model.base_model)
- setattr(uncondition, "base_model", self.model.base_model)
-
- def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
- if self.is_image_batch(data_batch) or not issubclass(base_class, ExtendVideoDiffusionModel):
- cond_x0 = self.denoise(noise_x, sigma, condition).x0
- uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0
- else:
- cond_x0 = self.denoise(
- noise_x,
- sigma,
- condition,
- condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
- seed_inference=seed_inference,
- ).x0_pred_replaced
- uncond_x0 = self.denoise(
- noise_x,
- sigma,
- uncondition,
- condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
- seed_inference=seed_inference,
- ).x0_pred_replaced
- return cond_x0 + guidance * (cond_x0 - uncond_x0)
-
- return x0_fn
-
- def generate_samples_from_batch(
- self,
- data_batch: Dict,
- guidance: float = 1.5,
- seed: int = 1,
- state_shape: Tuple | None = None,
- n_sample: int | None = None,
- is_negative_prompt: bool = False,
- num_steps: int = 35,
- condition_latent: Union[torch.Tensor, None] = None,
- num_condition_t: Union[int, None] = None,
- condition_video_augment_sigma_in_inference: float = None,
- solver_option: COMMON_SOLVER_OPTIONS = "2ab",
- x_sigma_max: Optional[torch.Tensor] = None,
- sigma_max: float | None = None,
- return_noise: bool = False,
- ) -> Tensor:
- """
- Generate samples from the batch. Based on the given batch, it will automatically determine whether to generate image or video samples.
- Different from the base model, this function supports a condition latent as input; it will create a different x0_fn if a condition latent is given.
- If this feature is stabilized, we could consider moving this function to the base model.
-
- Args:
- condition_latent (Optional[torch.Tensor]): latent tensor in shape B,C,T,H,W as condition to generate video.
- num_condition_t (Optional[int]): number of condition latent T, if None, will use the whole first half
-
- return_noise (bool): return the initial noise or not, used for ODE pairs generation. Not used here. Kept for compatibility.
- """ - self._normalize_video_databatch_inplace(data_batch) - self._augment_image_dim_inplace(data_batch) - is_image_batch = self.is_image_batch(data_batch) - if is_image_batch: - log.debug("image batch, call base model generate_samples_from_batch") - return super().generate_samples_from_batch( - data_batch, - guidance=guidance, - seed=seed, - state_shape=state_shape, - n_sample=n_sample, - is_negative_prompt=is_negative_prompt, - num_steps=num_steps, - ) - if n_sample is None: - input_key = self.input_image_key if is_image_batch else self.input_data_key - n_sample = data_batch[input_key].shape[0] - if state_shape is None: - if is_image_batch: - state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W - else: - log.debug(f"Default Video state shape is used. {self.state_shape}") - state_shape = self.state_shape - - # assert condition_latent is not None, "condition_latent should be provided" - - # if self.net.is_context_parallel_enabled: - # data_batch["latent_hint"] = split_inputs_cp(x=data_batch["latent_hint"], seq_dim=2, cp_group=self.net.cp_group) - - x0_fn = self.get_x0_fn_from_batch( - data_batch, - guidance, - is_negative_prompt=is_negative_prompt, - condition_latent=condition_latent, - num_condition_t=num_condition_t, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed_inference=seed, - ) - - if sigma_max is None: - sigma_max = self.sde.sigma_max - - if x_sigma_max is None: - x_sigma_max = ( - misc.arch_invariant_rand( - (n_sample,) + tuple(state_shape), - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - * sigma_max - ) - - if self.net.is_context_parallel_enabled: - x_sigma_max = _broadcast(x_sigma_max, to_tp=True, to_cp=True) - x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - - samples = self.sampler( - x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max, solver_option=solver_option - ) - if self.net.is_context_parallel_enabled: - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) - - return samples - - def get_patch_based_x0_fn( - self, - data_batch: Dict, - guidance: float = 1.5, - is_negative_prompt: bool = False, - condition_latent: torch.Tensor = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - target_h: int = 2112, - target_w: int = 3840, - patch_h: int = 704, - patch_w: int = 1280, - seed_inference: int = 1, - ) -> Callable: - """ - Generates a callable function `x0_fn` based on the provided data batch and guidance factor. - The function will split the input into patches, run inference on each patch, then stitch them together. - - Additional args to original function: - target_h (int): final stitched video height - target_w (int): final stitched video width - patch_h (int): video patch height for each network inference - patch_w (int): video patch width for each network inference - - Returns: - - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 prediction - """ - assert patch_h <= target_h and patch_w <= target_w - # data_batch should be the one processed by self.get_data_and_condition - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - if hasattr(self, "is_extend_model") and self.is_extend_model: - # Add conditions for long video generation. 
- if condition_latent is None: - condition_latent = torch.zeros(data_batch["latent_hint"].shape, **self.tensor_kwargs) - num_condition_t = 0 - condition_video_augment_sigma_in_inference = 1000 - - condition.video_cond_bool = True - condition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent[:1], condition, num_condition_t - ) - uncondition.video_cond_bool = True # Not do cfg on condition frames - uncondition = self.add_condition_video_indicator_and_video_input_mask( - condition_latent[:1], uncondition, num_condition_t - ) - # Add extra conditions for ctrlnet. - latent_hint = data_batch["latent_hint"] - hint_key = data_batch["hint_key"] - setattr(condition, hint_key, latent_hint) - if "use_none_hint" in data_batch and data_batch["use_none_hint"]: - setattr(uncondition, hint_key, None) - else: - setattr(uncondition, hint_key, latent_hint) - - # Handle regional prompting information - if "regional_contexts" in data_batch: - setattr(condition, "regional_contexts", data_batch["regional_contexts"]) - # For unconditioned generation, we still need the region masks but not the regional contexts - setattr(uncondition, "regional_contexts", None) # No regional contexts for unconditioned generation - original_region_masks = None - if "region_masks" in data_batch: - original_region_masks = data_batch["region_masks"] - setattr(condition, "region_masks", data_batch["region_masks"]) - # For unconditioned generation, we still need the region masks but not the regional contexts - setattr(uncondition, "region_masks", data_batch["region_masks"]) - - to_cp = self.net.is_context_parallel_enabled - # For inference, check if parallel_state is initialized - if parallel_state.is_initialized() and not self.is_image_batch(data_batch): - condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) - uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) - cp_group = parallel_state.get_context_parallel_group() - latent_hint = getattr(condition, hint_key) - latent_hint = split_inputs_cp(latent_hint, seq_dim=2, cp_group=cp_group) - - if hasattr(condition, "regional_contexts") and getattr(condition, "regional_contexts") is not None: - regional_contexts = getattr(condition, "regional_contexts") - regional_contexts = split_inputs_cp(regional_contexts, seq_dim=2, cp_group=cp_group) - setattr(condition, "regional_contexts", regional_contexts) - if hasattr(condition, "region_masks") and getattr(condition, "region_masks") is not None: - region_masks = getattr(condition, "region_masks") - region_masks = split_inputs_cp(region_masks, seq_dim=2, cp_group=cp_group) - setattr(condition, "region_masks", region_masks) - - setattr(condition, "base_model", self.model.base_model) - setattr(uncondition, "base_model", self.model.base_model) - if hasattr(self, "hint_encoders"): - self.model.net.hint_encoders = self.hint_encoders - - def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor): - w, h = target_w, target_h - n_img_w = (w - 1) // patch_w + 1 - n_img_h = (h - 1) // patch_h + 1 - - overlap_size_w = overlap_size_h = 0 - if n_img_w > 1: - overlap_size_w = (n_img_w * patch_w - w) // (n_img_w - 1) - assert n_img_w * patch_w - overlap_size_w * (n_img_w - 1) == w - if n_img_h > 1: - overlap_size_h = (n_img_h * patch_h - h) // (n_img_h - 1) - assert n_img_h * patch_h - overlap_size_h * (n_img_h - 1) == h - - batch_images = noise_x - batch_sigma = sigma - output = [] - for idx, cur_images in enumerate(batch_images): - noise_x = cur_images.unsqueeze(0) - sigma = batch_sigma[idx : idx + 1] - 
condition.gt_latent = condition_latent[idx : idx + 1] - uncondition.gt_latent = condition_latent[idx : idx + 1] - setattr(condition, hint_key, latent_hint[idx : idx + 1]) - if getattr(uncondition, hint_key) is not None: - setattr(uncondition, hint_key, latent_hint[idx : idx + 1]) - - if self.is_image_batch(data_batch) or not issubclass(base_class, ExtendVideoDiffusionModel): - cond_x0 = self.denoise(noise_x, sigma, condition).x0 - uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 - else: - cond_x0 = self.denoise( - noise_x, - sigma, - condition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed_inference=seed_inference, - ).x0_pred_replaced - uncond_x0 = self.denoise( - noise_x, - sigma, - uncondition, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - seed_inference=seed_inference, - ).x0_pred_replaced - x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) - output.append(x0) - output = rearrange(torch.stack(output), "(n t) b ... -> (b n t) ...", n=n_img_h, t=n_img_w) # 8x3xhxw - final_output = merge_patches_into_video(output, overlap_size_h, overlap_size_w, n_img_h, n_img_w) - final_output = split_video_into_patches(final_output, patch_h, patch_w) - return final_output - - return x0_fn - - def generate_samples_from_patches( - self, - data_batch: Dict, - guidance: float = 1.5, - seed: int = 1, - state_shape: Tuple | None = None, - n_sample: int | None = None, - is_negative_prompt: bool = False, - num_steps: int = 35, - condition_latent: Union[torch.Tensor, None] = None, - num_condition_t: Union[int, None] = None, - condition_video_augment_sigma_in_inference: float = None, - solver_option: COMMON_SOLVER_OPTIONS = "2ab", - x_sigma_max: Optional[torch.Tensor] = None, - sigma_max: float | None = None, - target_h: int = 2112, - target_w: int = 3840, - patch_h: int = 704, - patch_w: int = 1280, - ) -> Tensor: - """ - Generate samples from the batch using patch-based inference. During each denoising step, it will denoise each patch - separately then average the overlapping regions. - - Additional args to original function: - target_h (int): final stitched video height - target_w (int): final stitched video width - patch_h (int): video patch height for each network inference - patch_w (int): video patch width for each network inference - """ - assert patch_h <= target_h and patch_w <= target_w - self._normalize_video_databatch_inplace(data_batch) - self._augment_image_dim_inplace(data_batch) - is_image_batch = self.is_image_batch(data_batch) - if is_image_batch: - log.debug("image batch, call base model generate_samples_from_batch") - return super().generate_samples_from_batch( - data_batch, - guidance=guidance, - seed=seed, - state_shape=state_shape, - n_sample=n_sample, - is_negative_prompt=is_negative_prompt, - num_steps=num_steps, - ) - if n_sample is None: - input_key = self.input_image_key if is_image_batch else self.input_data_key - n_sample = data_batch[input_key].shape[0] - if state_shape is None: - if is_image_batch: - state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W - else: - log.debug(f"Default Video state shape is used. 
{self.state_shape}") - state_shape = self.state_shape - - x0_fn = self.get_patch_based_x0_fn( - data_batch, - guidance, - is_negative_prompt=is_negative_prompt, - condition_latent=condition_latent, - num_condition_t=num_condition_t, - condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference, - target_h=target_h, - target_w=target_w, - patch_h=patch_h, - patch_w=patch_w, - seed_inference=seed, - ) - - if sigma_max is None: - sigma_max = self.sde.sigma_max - - if x_sigma_max is None: - x_sigma_max = ( - misc.arch_invariant_rand( - (n_sample,) + tuple(state_shape), - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - * sigma_max - ) - - if self.net.is_context_parallel_enabled: - x_sigma_max = _broadcast(x_sigma_max, to_tp=True, to_cp=True) - x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - - samples = self.sampler( - x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max, solver_option=solver_option - ) - if self.net.is_context_parallel_enabled: - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) - - return samples - - @torch.no_grad() - def validation_step( - self, data: dict[str, torch.Tensor], iteration: int - ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: - """ - save generated videos - """ - raw_data, x0, condition = self.get_data_and_condition(data) - guidance = data["guidance"] - sigma_max = data["sigma_max"] - is_negative_prompt = data["is_negative_prompt"] - data = misc.to(data, **self.tensor_kwargs) - x_sigma_max = None - if sigma_max is not None: - x_sigma_max = self.get_x_from_clean(x0, sigma_max) - sample = self.generate_samples_from_batch( - data, - guidance=guidance, - # make sure no mismatch and also works for cp - state_shape=x0.shape[1:], - n_sample=x0.shape[0], - x_sigma_max=x_sigma_max, - sigma_max=sigma_max, - is_negative_prompt=is_negative_prompt, - ) - sample = self.decode(sample) - gt = raw_data - hint = data[data["hint_key"]][:, :3] - result = torch.cat([hint, sample], dim=3) - gt = torch.cat([hint, gt], dim=3) - caption = data["ai_caption"] - return {"gt": gt, "result": result, "caption": caption}, torch.tensor([0]).to(**self.tensor_kwargs) - - return VideoDiffusionModelWithCtrlWrapper - - -@video_ctrlnet_decorator -@ctrlnet_decorator -class VideoDiffusionModelWithCtrl(ExtendVideoDiffusionModel): - pass - - -@diffusion_fsdp_class_decorator -@video_ctrlnet_decorator -@ctrlnet_decorator -class VideoDiffusionFSDPModelWithCtrl(ExtendVideoDiffusionModel): - pass - - -@video_ctrlnet_decorator -@ctrlnet_decorator -class ShortVideoDiffusionModelWithCtrl(VideoDiffusionModel): - pass - - -@diffusion_fsdp_class_decorator -@video_ctrlnet_decorator -@ctrlnet_decorator -class ShortVideoDiffusionFSDPModelWithCtrl(VideoDiffusionModel): - pass diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/model_image.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/model_image.py deleted file mode 100644 index 53ec975adb84739d37f22512a3cf041669108ee0..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/model_image.py +++ /dev/null @@ -1,928 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -from contextlib import contextmanager -from dataclasses import dataclass, fields -from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, TypeVar - -import numpy as np -import torch -import torch.nn.functional as F -from megatron.core import parallel_state -from torch.distributed.fsdp import FullStateDictConfig -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import ShardingStrategy, StateDictType -from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy -from torch.nn.modules.module import _IncompatibleKeys - -from cosmos_transfer1.diffusion.diffusion.modules.denoiser_scaling import EDMScaling -from cosmos_transfer1.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS, Sampler -from cosmos_transfer1.diffusion.diffusion.types import DenoisePrediction -from cosmos_transfer1.diffusion.functional.batch_ops import batch_mul -from cosmos_transfer1.diffusion.module.blocks import FourierFeatures -from cosmos_transfer1.diffusion.module.pretrained_vae import BaseVAE -from cosmos_transfer1.diffusion.training.functional.loss import create_per_sample_loss_mask -from cosmos_transfer1.diffusion.training.utils.fsdp_helper import apply_fsdp_checkpointing, hsdp_device_mesh -from cosmos_transfer1.diffusion.training.utils.optim_instantiate import get_base_scheduler -from cosmos_transfer1.utils import distributed, log, misc -from cosmos_transfer1.utils.ema import FastEmaModelUpdater -from cosmos_transfer1.utils.lazy_config import LazyDict -from cosmos_transfer1.utils.lazy_config import instantiate as lazy_instantiate -from cosmos_transfer1.utils.model import Model - - -@dataclass -class CosmosCondition: - crossattn_emb: torch.Tensor - crossattn_mask: torch.Tensor - padding_mask: Optional[torch.Tensor] = None - scalar_feature: Optional[torch.Tensor] = None - - def to_dict(self) -> Dict[str, Optional[torch.Tensor]]: - return {f.name: getattr(self, f.name) for f in fields(self)} - - -class DiffusionModel(Model): - def __init__(self, config): - super().__init__() - - self.config = config - - # how many sample have been processed - self.sample_counter = 0 - self.precision = { - "float32": torch.float32, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - }[config.precision] - self.tensor_kwargs = {"device": "cuda", "dtype": self.precision} - log.warning(f"DiffusionModel: precision {self.precision}") - # Timer passed to network to detect slow ranks. - # 1. set data keys and data information - self.sigma_data = config.sigma_data - self.state_shape = list(config.latent_shape) - self.setup_data_key() - - # 2. setup up diffusion processing and scaling~(pre-condition), sampler - self.sde = lazy_instantiate(config.sde) - self.sampler = Sampler() - self.scaling = EDMScaling(self.sigma_data) - - # 3. vae - with misc.timer("DiffusionModel: set_up_vae"): - self.vae: BaseVAE = lazy_instantiate(config.tokenizer) - assert ( - self.vae.latent_ch == self.state_shape[0] - ), f"latent_ch {self.vae.latent_ch} != state_shape {self.state_shape[0]}" - - # 4. 
Set up loss options, including loss masking, loss reduce and loss scaling - self.loss_masking: Optional[Dict] = config.loss_masking - self.loss_reduce = getattr(config, "loss_reduce", "mean") - assert self.loss_reduce in ["mean", "sum"] - self.loss_scale = getattr(config, "loss_scale", 1.0) - log.critical(f"Using {self.loss_reduce} loss reduce with loss scale {self.loss_scale}") - log.critical(f"Enable loss masking: {config.loss_mask_enabled}") - - # 5. diffusion neural networks part - self.set_up_model() - - def setup_data_key(self) -> None: - self.input_data_key = self.config.input_data_key - - def build_model(self) -> torch.nn.ModuleDict: - config = self.config - net = lazy_instantiate(config.net) - conditioner = lazy_instantiate(config.conditioner) - logvar = torch.nn.Sequential( - FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False) - ) - - return torch.nn.ModuleDict( - { - "net": net, - "conditioner": conditioner, - "logvar": logvar, - } - ) - - @misc.timer("DiffusionModel: set_up_model") - def set_up_model(self): - config = self.config - self.model = self.build_model() - if config.ema.enabled: - with misc.timer("DiffusionModel: instantiate ema"): - config.ema.model = self.model - self.model_ema = lazy_instantiate(config.ema) - config.ema.model = None - else: - self.model_ema = None - - @property - def net(self): - return self.model.net - - @property - def conditioner(self): - return self.model.conditioner - - def on_before_zero_grad( - self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int - ) -> None: - """ - update the model_ema - """ - if self.config.ema.enabled: - self.model_ema.update_average(self.model, iteration) - - def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: - if self.config.ema.enabled: - self.model_ema.to(dtype=torch.float32) - if hasattr(self.vae, "reset_dtype"): - self.vae.reset_dtype() - self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs) - - if hasattr(self.config, "use_torch_compile") and self.config.use_torch_compile: # compatible with old config - if torch.__version__ < "2.3": - log.warning( - "torch.compile in Pytorch version older than 2.3 doesn't work well with activation checkpointing.\n" - "It's very likely there will be no significant speedup from torch.compile.\n" - "Please use at least 24.04 Pytorch container." - ) - # Increasing cache size. It's required because of the model size and dynamic input shapes resulting in - # multiple different triton kernels. For 28 TransformerBlocks, the cache limit of 256 should be enough for - # up to 9 different input shapes, as 28*9 < 256. If you have more Blocks or input shapes, and you observe - # graph breaks at each Block (detectable with torch._dynamo.explain) or warnings about - # exceeding cache limit, you may want to increase this size. - # Starting with 24.05 Pytorch container, the default value is 256 anyway. - # You can read more about it in the comments in Pytorch source code under path torch/_dynamo/cache_size.py. - torch._dynamo.config.accumulated_cache_size_limit = 256 - # dynamic=False means that a separate kernel is created for each shape. It incurs higher compilation costs - # at initial iterations, but can result in more specialized and efficient kernels. - # dynamic=True currently throws errors in pytorch 2.3. 
- self.model.net = torch.compile(self.model.net, dynamic=False, disable=not self.config.use_torch_compile)
-
- def compute_loss_with_epsilon_and_sigma(
- self,
- data_batch: dict[str, torch.Tensor],
- x0_from_data_batch: torch.Tensor,
- x0: torch.Tensor,
- condition: CosmosCondition,
- epsilon: torch.Tensor,
- sigma: torch.Tensor,
- ):
- """
- Compute loss given epsilon and sigma.
-
- This method is responsible for computing the loss given epsilon and sigma. It involves:
- 1. Adding noise to the input data using the SDE process.
- 2. Passing the noisy data through the network to generate predictions.
- 3. Computing the loss based on the difference between the predictions and the original data, \
- considering any configured loss weighting.
-
- Args:
- data_batch (dict): raw data batch drawn from the training data loader.
- x0_from_data_batch: raw image/video
- x0: image/video latent
- condition: text condition
- epsilon: noise
- sigma: noise level
-
- Returns:
- tuple: A tuple containing four elements:
- - dict: additional data used for debugging / logging / callbacks
- - Tensor 1: kendall loss,
- - Tensor 2: MSE loss,
- - Tensor 3: EDM loss
-
- Raises:
- AssertionError: If the class is conditional, \
- but no number of classes is specified in the network configuration.
-
- Notes:
- - The method handles different types of conditioning
- - The method also supports Kendall's loss
- """
- # Get the mean and standard deviation of the marginal probability distribution.
- mean, std = self.sde.marginal_prob(x0, sigma)
- # Generate noisy observations
- xt = mean + batch_mul(std, epsilon) # corrupted data
- # make prediction
- model_pred = self.denoise(xt, sigma, condition)
- # loss weights for different noise levels
- weights_per_sigma = self.get_per_sigma_loss_weights(sigma=sigma)
- # extra weight for each sample, for example, aesthetic weight, camera weight
- weights_per_sample = self.get_per_sample_weight(data_batch, x0_from_data_batch.shape[0])
- # extra loss mask for each sample, for example, human faces, hands
- loss_mask_per_sample = self.get_per_sample_loss_mask(data_batch, x0_from_data_batch.shape, x0.shape)
- pred_mse = (x0 - model_pred.x0) ** 2 * loss_mask_per_sample
- edm_loss = batch_mul(pred_mse, weights_per_sigma * weights_per_sample)
- if self.config.loss_add_logvar:
- kendall_loss = batch_mul(edm_loss, torch.exp(-model_pred.logvar).view(-1)).flatten(
- start_dim=1
- ) + model_pred.logvar.view(-1, 1)
- else:
- kendall_loss = edm_loss.flatten(start_dim=1)
- output_batch = {
- "x0": x0,
- "xt": xt,
- "sigma": sigma,
- "weights_per_sigma": weights_per_sigma,
- "weights_per_sample": weights_per_sample,
- "loss_mask_per_sample": loss_mask_per_sample,
- "condition": condition,
- "model_pred": model_pred,
- "mse_loss": pred_mse.mean(),
- "edm_loss": edm_loss.mean(),
- }
- return output_batch, kendall_loss, pred_mse, edm_loss
-
- def training_step(
- self, data_batch: dict[str, torch.Tensor], iteration: int
- ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
- """
- Performs a single training step for the diffusion model.
-
- This method is responsible for executing one iteration of the model's training. It involves:
- 1. Adding noise to the input data using the SDE process.
- 2. Passing the noisy data through the network to generate predictions.
- 3. Computing the loss based on the difference between the predictions and the original data, \
- considering any configured loss weighting.
-
- Args:
- data_batch (dict): raw data batch drawn from the training data loader.
- iteration (int): Current iteration number. - - Returns: - tuple: A tuple containing two elements: - - dict: additional data that used to debug / logging / callbacks - - Tensor: The computed loss for the training step as a PyTorch Tensor. - - Raises: - AssertionError: If the class is conditional, \ - but no number of classes is specified in the network configuration. - - Notes: - - The method handles different types of conditioning - - The method also supports Kendall's loss - """ - # Get the input data to noise and denoise~(image, video) and the corresponding conditioner. - x0_from_data_batch, x0, condition = self.get_data_and_condition(data_batch) - - # Sample pertubation noise levels and N(0, 1) noises - sigma, epsilon = self.draw_training_sigma_and_epsilon(x0.size(), condition) - - output_batch, kendall_loss, pred_mse, edm_loss = self.compute_loss_with_epsilon_and_sigma( - data_batch, x0_from_data_batch, x0, condition, epsilon, sigma - ) - - if self.loss_reduce == "mean": - kendall_loss = kendall_loss.mean() * self.loss_scale - elif self.loss_reduce == "sum": - kendall_loss = kendall_loss.sum(dim=1).mean() * self.loss_scale - else: - raise ValueError(f"Invalid loss_reduce: {self.loss_reduce}") - - return output_batch, kendall_loss - - def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: CosmosCondition) -> DenoisePrediction: - """ - Performs denoising on the input noise data, noise level, and condition - - Args: - xt (torch.Tensor): The input noise data. - sigma (torch.Tensor): The noise level. - condition (CosmosCondition): conditional information, generated from self.conditioner - - Returns: - DenoisePrediction: The denoised prediction, it includes clean data predicton (x0), \ - noise prediction (eps_pred) and optional confidence (logvar). - """ - - if getattr(self.config, "use_dummy_temporal_dim", False): - # When using video DiT model for image, we need to use a dummy temporal dimension. - xt = xt.unsqueeze(2) - - xt = xt.to(**self.tensor_kwargs) - sigma = sigma.to(**self.tensor_kwargs) - # get precondition for the network - c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma) - - # forward pass through the network - net_output = self.net( - x=batch_mul(c_in, xt), # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf - timesteps=c_noise, # Eq. 
7 of https://arxiv.org/pdf/2206.00364.pdf - **condition.to_dict(), - ) - - logvar = self.model.logvar(c_noise) - x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output) - - # get noise prediction based on sde - eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma) - - if getattr(self.config, "use_dummy_temporal_dim", False): - x0_pred = x0_pred.squeeze(2) - eps_pred = eps_pred.squeeze(2) - - return DenoisePrediction(x0_pred, eps_pred, logvar) - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - return self.vae.encode(state) * self.sigma_data - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - return self.vae.decode(latent / self.sigma_data) - - def draw_training_sigma_and_epsilon(self, x0_size: int, condition: Any) -> torch.Tensor: - del condition - batch_size = x0_size[0] - epsilon = torch.randn(x0_size, **self.tensor_kwargs) - return self.sde.sample_t(batch_size).to(**self.tensor_kwargs), epsilon - - def get_data_and_condition(self, data_batch: dict[str, torch.Tensor]) -> Tuple[torch.Tensor, CosmosCondition]: - """ - processing data batch draw from data loader and return data and condition that used for denoising task - - Returns: - raw_state (tensor): the image / video data that feed to vae - latent_state (tensor): nosie-free state, the vae latent state - condition (CosmosCondition): condition information for conditional generation. Generated from conditioner - """ - raw_state = data_batch[self.input_data_key] - latent_state = self.encode(raw_state) - condition = self.conditioner(data_batch) - return raw_state, latent_state, condition - - def get_per_sample_weight(self, data_batch: dict[str, torch.Tensor], batch_size: int): - r""" - extra weight for each sample, for example, aesthetic weight - Args: - data_batch: raw data batch draw from the training data loader. - batch_size: int, the batch size of the input data - """ - aesthetic_cfg = getattr(self.config, "aesthetic_finetuning", None) - if (aesthetic_cfg is not None) and getattr(aesthetic_cfg, "enabled", False): - sample_weight = data_batch["aesthetic_weight"] - else: - sample_weight = torch.ones(batch_size, **self.tensor_kwargs) - - camera_cfg = getattr(self.config, "camera_sample_weight", None) - if (camera_cfg is not None) and getattr(camera_cfg, "enabled", False): - sample_weight *= 1 + (data_batch["camera_attributes"][:, 1:].sum(dim=1) != 0) * (camera_cfg.weight - 1) - return sample_weight - - def get_per_sample_loss_mask(self, data_batch, raw_x_shape, latent_x_shape): - """ - extra loss mask for each sample, for example, human faces, hands. - - Args: - data_batch (dict): raw data batch draw from the training data loader. - raw_x_shape (tuple): shape of the input data. We need the raw_x_shape for necessary resize operation. - latent_x_shape (tuple): shape of the latent data - """ - if self.config.loss_mask_enabled: - raw_x_shape = [raw_x_shape[0], 1, *raw_x_shape[2:]] - weights = create_per_sample_loss_mask( - self.loss_masking, data_batch, raw_x_shape, torch.get_default_dtype(), "cuda" - ) - return F.interpolate(weights, size=latent_x_shape[2:], mode="bilinear") - - return 1.0 - - def get_per_sigma_loss_weights(self, sigma: torch.Tensor): - """ - Args: - sigma (tensor): noise level - - Returns: - loss weights per sigma noise level - """ - return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - - def generate_samples(self, batch_size: int, condition: CosmosCondition) -> torch.Tensor: - """ - Generate samples with given condition. 
It is WITHOUT classifier-free-guidance. - - Args: - batch_size (int): - condition (CosmosCondition): condition information generated from self.conditioner - """ - x_sigma_max = torch.randn(batch_size, *self.state_shape, **self.tensor_kwargs) * self.sde.sigma_max - - def x0_fn(x, t): - return self.denoise(x, t, condition).x0 # ODE function - - return self.sampler(x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max) - - def generate_cfg_samples( - self, batch_size: int, condition: CosmosCondition, uncondition: CosmosCondition, guidance=1.5 - ) -> torch.Tensor: - """ - Generate samples with with classifier-free-guidance. - - Args: - batch_size (int): - condition (CosmosCondition): condition information generated from self.conditioner - uncondition (CosmosCondition): uncondition information, possibily generated from self.conditioner - """ - x_sigma_max = torch.randn(batch_size, *self.state_shape, **self.tensor_kwargs) * self.sde.sigma_max - - def x0_fn(x, t): - cond_x0 = self.denoise(x, t, condition).x0 - uncond_x0 = self.denoise(x, t, uncondition).x0 - return cond_x0 + guidance * (cond_x0 - uncond_x0) - - return self.sampler(x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max) - - def get_x0_fn_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - is_negative_prompt: bool = False, - ) -> Callable: - """ - Generates a callable function `x0_fn` based on the provided data batch and guidance factor. - - This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. - - Args: - - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` - - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. - - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - - Returns: - - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin - - The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. - """ - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - - def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - cond_x0 = self.denoise(noise_x, sigma, condition).x0 - uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 - return cond_x0 + guidance * (cond_x0 - uncond_x0) - - return x0_fn - - def generate_samples_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - seed: int = 1, - state_shape: Optional[Tuple] = None, - n_sample: Optional[int] = None, - is_negative_prompt: bool = False, - num_steps: int = 35, - solver_option: COMMON_SOLVER_OPTIONS = "2ab", - ) -> torch.Tensor: - """ - Args: - data_batch (dict): raw data batch draw from the training data loader. - iteration (int): Current iteration number. 
- guidance (float): guidance weights - seed (int): random seed - state_shape (tuple): shape of the state, default to self.state_shape if not provided - n_sample (int): number of samples to generate - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - num_steps (int): number of steps for the diffusion process - solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) - """ - x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt) - batch_size = n_sample or data_batch[self.input_data_key].shape[0] - state_shape = state_shape or self.state_shape - x_sigma_max = ( - misc.arch_invariant_rand( - (batch_size,) + tuple(state_shape), - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - * self.sde.sigma_max - ) - return self.sampler( - x0_fn, x_sigma_max, sigma_max=self.sde.sigma_max, num_steps=num_steps, solver_option=solver_option - ) - - @torch.no_grad() - def validation_step( - self, data: dict[str, torch.Tensor], iteration: int - ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: - """ - Current code does nothing. - """ - return {}, torch.tensor(0).to(**self.tensor_kwargs) - - @torch.no_grad() - def forward(self, xt, t, condition: CosmosCondition): - """ - Performs denoising on the input noise data, noise level, and condition - - Args: - xt (torch.Tensor): The input noise data. - sigma (torch.Tensor): The noise level. - condition (CosmosCondition): conditional information, generated from self.conditioner - - Returns: - DenoisePrediction: The denoised prediction, it includes clean data predicton (x0), \ - noise prediction (eps_pred) and optional confidence (logvar). - """ - return self.denoise(xt, t, condition) - - def init_optimizer_scheduler( - self, optimizer_config: LazyDict, scheduler_config: LazyDict - ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: - """Creates the optimizer and scheduler for the model. - - Args: - config_model (ModelConfig): The config object for the model. - - Returns: - optimizer (torch.optim.Optimizer): The model optimizer. - scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. - """ - optimizer = lazy_instantiate(optimizer_config, model=self.model) - scheduler = get_base_scheduler(optimizer, self, scheduler_config) - return optimizer, scheduler - - def state_dict(self) -> Dict[str, Any]: - """ - Returns the current state of the model as a dictionary. - - Returns: - Dict: The current state of the model as a dictionary. - """ - return { - "model": self.model.state_dict(), - "ema": self.model_ema.state_dict() if self.config.ema.enabled else None, - } - - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False): - """ - Loads a state dictionary into the model and optionally its EMA counterpart. - Different from torch strict=False mode, the method will not raise error for unmatched state shape while raise warning. - - Parameters: - state_dict (Mapping[str, Any]): A dictionary containing separate state dictionaries for the model and - potentially for an EMA version of the model under the keys 'model' and 'ema', respectively. - strict (bool, optional): If True, the method will enforce that the keys in the state dict match exactly - those in the model and EMA model (if applicable). Defaults to True. - assign (bool, optional): If True and in strict mode, will assign the state dictionary directly rather than - matching keys one-by-one. 
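The sampling helpers above (`generate_cfg_samples`, `get_x0_fn_from_batch`, `generate_samples_from_batch`) all reduce to the same classifier-free-guidance combination of a conditioned and an unconditioned x0 prediction. A small sketch of that combination, with toy shapes standing in for the real latent video states:

```python
import torch

def cfg_combine(cond_x0: torch.Tensor, uncond_x0: torch.Tensor, guidance: float = 1.5) -> torch.Tensor:
    # Convention used above: extrapolate from the *conditioned* prediction,
    # so guidance = 0 reproduces the conditional model output.
    return cond_x0 + guidance * (cond_x0 - uncond_x0)

# Toy shapes only; the real tensors are latent video states of shape (B, C, T, H, W).
cond_x0 = torch.randn(2, 16, 8, 44, 80)
uncond_x0 = torch.randn(2, 16, 8, 44, 80)
guided_x0 = cfg_combine(cond_x0, uncond_x0, guidance=1.5)
```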
This is typically used when loading parts of state dicts - or using customized loading procedures. Defaults to False. - """ - if strict: - if self.config.ema.enabled and state_dict.get("ema") is not None: - ema_results: _IncompatibleKeys = self.model_ema.load_state_dict( - state_dict["ema"], strict=strict, assign=assign - ) - reg_results: _IncompatibleKeys = self.model.load_state_dict( - state_dict["model"], strict=strict, assign=assign - ) - if self.config.ema.enabled and state_dict.get("ema") is not None: - return _IncompatibleKeys( - ema_results.missing_keys + reg_results.missing_keys, - ema_results.unexpected_keys + reg_results.unexpected_keys, - ) - return reg_results - else: - from cosmos_transfer1.diffusion.inference.inference_utils import non_strict_load_model - - log.critical("load model in non-strict mode") - log.critical(non_strict_load_model(self.model, state_dict["model"]), rank0_only=False) - if self.config.ema.enabled and state_dict.get("ema") is not None: - log.critical("load ema model in non-strict mode") - log.critical(non_strict_load_model(self.model_ema, state_dict["ema"]), rank0_only=False) - - def get_ckpt_postfix(self) -> Tuple[str, int, int]: - """Get the checkpoint file postfix. - - Args: - iteration (int): The current iteration number. - - Returns: - postfix (str): The postfix of the checkpoint file. - rank_to_save ema (int), we will not save each ema model in each rank, \ - ema model with same rate will be saved once - total_ema_num (int) - """ - total_ema_num = min(self.config.ema.num, distributed.get_world_size()) - rank = distributed.get_rank() - if rank == 0: - return "", 0, total_ema_num - if self.config.ema.enabled: - if rank < self.config.ema.num: - return f"_RANK{rank}", rank, total_ema_num - return "", 0, total_ema_num # use rank 0 to save the checkpoint - - @contextmanager - def ema_scope(self, context=None, is_cpu=False): - if self.config.ema.enabled: - self.model_ema.cache(self.model.parameters(), is_cpu=is_cpu) - self.model_ema.copy_to(self.model) - if context is not None: - log.info(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.config.ema.enabled: - self.model_ema.restore(self.model.parameters()) - if context is not None: - log.info(f"{context}: Restored training weights") - - -T = TypeVar("T", bound=DiffusionModel) - - -def diffusion_fsdp_class_decorator(base_class: Type[T]) -> Type[T]: - """ - Decorator for the FSDP class for the diffusion model, which handles the FSDP specific logic for the diffusion model. - """ - - class FSDPClass(base_class): - """ - Handle FSDP specific logic for the diffusion model. Including: - - FSDP model initialization - - FSDP model / optimizer save and loading - - Different from the original DiffusionModel, the impl of multi-rank EMA is a bit hacky. \ - We need to make sure sharded model weights for EMA and regular model are the same. - """ - - def __init__(self, config, fsdp_checkpointer: Any): - self.fsdp_checkpointer = fsdp_checkpointer - super().__init__(config) - - def set_up_model(self): - config = self.config - - # 1. 
build FSDP sharding strategy and device_mesh - strategy = { - "full": ShardingStrategy.FULL_SHARD, - "hybrid": ShardingStrategy.HYBRID_SHARD, - }[config.fsdp.sharding_strategy] - log.critical(f"Using {strategy} sharding strategy for FSDP") - - if config.fsdp.sharding_strategy == "hybrid": - sharding_group_size = getattr(config.fsdp, "sharding_group_size", 8) - device_mesh = hsdp_device_mesh( - sharding_group_size=sharding_group_size, - ) - shard_group = device_mesh.get_group(mesh_dim="shard") - replicate_group = device_mesh.get_group(mesh_dim="replicate") - fsdp_process_group = (shard_group, replicate_group) - else: - device_mesh = hsdp_device_mesh( - sharding_group_size=distributed.get_world_size(), - ) - shard_group = device_mesh.get_group(mesh_dim="shard") - fsdp_process_group = shard_group - - # We piggyback the `device_mesh` to megatron-core's `parallel_state` for global access. - # This is not megatron-core's original API. - parallel_state.fsdp_device_mesh = device_mesh - - def get_wrap_policy(_model): - if not hasattr(_model.net, "fsdp_wrap_block_cls"): - raise ValueError( - "Networks does not have fsdp_wrap_block_cls attribute, please check the net definition" - ) - fsdp_blocks_cls = _model.net.fsdp_wrap_block_cls - fsdp_blocks_cls = ( - list(fsdp_blocks_cls) if isinstance(fsdp_blocks_cls, (list, tuple, set)) else [fsdp_blocks_cls] - ) - log.critical(f"Using FSDP blocks {fsdp_blocks_cls}") - - log.critical(f"Using wrap policy {config.fsdp.policy}") - if config.fsdp.policy == "size": - min_num_params = getattr(config.fsdp, "min_num_params", 100) - log.critical(f"Using {min_num_params} as the minimum number of parameters for auto-wrap policy") - wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params) - else: - from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy - - wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls=set(fsdp_blocks_cls), - ) - return wrap_policy - - # 2. build naive pytorch model and load weights if exists - replica_idx, shard_idx = device_mesh.get_coordinate() - # 2.1 handle ema case first, since float32 is more expensive - if config.ema.enabled: - with misc.timer("Creating PyTorch model and loading weights for ema"): - model_ema = self.build_model().float() - model_ema.cuda().eval().requires_grad_(False) - if distributed.get_rank() == 0: - # only load model in rank0 to reduce network traffic - self.fsdp_checkpointer.load_model_during_init(model_ema, is_ema=True) - # sync ema model weights from rank0 - with misc.timer("Sync model states for EMA model"): - #! this is IMPORTANT, see the following comment about regular model for details - #! 
we broadcast the ema model first, since it is fp32 and costs more memory - distributed.sync_model_states(model_ema, device_mesh.get_group(mesh_dim="shard")) - torch.cuda.empty_cache() - distributed.sync_model_states(model_ema, device_mesh.get_group(mesh_dim="replicate")) - torch.cuda.empty_cache() - # for ema model with dfiferent rate, we download the model when necessary - if shard_idx == 0 and replica_idx > 0 and replica_idx < config.ema.num: - print("loading ema model in rank", replica_idx) - self.fsdp_checkpointer.load_model_during_init( - model_ema, - is_ema=True, - ema_id=replica_idx, - ) - print("finish loading ema model in rank", replica_idx) - # 2.1.2 create FSDP model for ema model - with misc.timer("Creating FSDP model for EMA model"): - self.model_ema = FSDP( - model_ema, - sync_module_states=True, # it can reduce network traffic by only loading model in rank0 and sync - process_group=device_mesh.get_group(mesh_dim=1), - sharding_strategy=ShardingStrategy.FULL_SHARD, - auto_wrap_policy=get_wrap_policy(model_ema), - device_id=torch.cuda.current_device(), - limit_all_gathers=True, - ) - - # extra ema model upate logic to the model - self.model_ema_worker = FastEmaModelUpdater() - s = 0.1 - replica_idx, shard_idx = device_mesh.get_coordinate() - divider = 2**replica_idx if replica_idx < config.ema.num else 1 - if replica_idx < config.ema.num: - if shard_idx == 0: - print(f"EMA: rank {replica_idx}, rate {config.ema.rate / divider}") - s = config.ema.rate / divider - self.ema_exp_coefficient = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max() - - torch.cuda.empty_cache() - - # 2.2 handle regular model - with misc.timer("Creating PyTorch model and loading weights for regular model"): - model = self.build_model().cuda().to(**self.tensor_kwargs) - - if distributed.get_rank() == 0: - # only load model in rank0 to reduce network traffic and sync later - self.fsdp_checkpointer.load_model_during_init(model, is_ema=False) - - #! overwrite the forward method so that it will invoke the FSDP-specific pre- and post-forward sharding logic - model.forward = super().training_step - #! this is IMPORTANT, though following two lines are identical to sync_module_states=True in FSDP - #! we do it twice so that following line can warm up and avoid OOM in 128+ nodes settings - #! without it, peak mem : reg_model + ema_model + FSDP overhead + nccl communication initialization overhead - #! with it, peak men: reg_model + ema_model + FSDP overhead - #! it is tricky, but it works! 
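The FSDP wrapping below hinges on `get_wrap_policy`, which either wraps by parameter count or wraps each transformer block class exposed via `net.fsdp_wrap_block_cls`. A minimal sketch of the same policy construction, with a hypothetical `TransformerBlock` standing in for the real block class; the actual `FSDP(...)` call additionally needs an initialized process group / device mesh, as in `set_up_model`.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

class TransformerBlock(nn.Module):  # hypothetical stand-in for net.fsdp_wrap_block_cls
    def __init__(self):
        super().__init__()
        self.ff = nn.Linear(512, 512)

def build_wrap_policy(policy: str, block_cls=TransformerBlock, min_num_params: int = 100):
    # Mirrors get_wrap_policy above: wrap by parameter count, or wrap each transformer block.
    if policy == "size":
        return functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)
    return functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={block_cls})

wrap_policy = build_wrap_policy("transformer")
# The policy is then passed as FSDP(model, auto_wrap_policy=wrap_policy, ...) once the
# device mesh and process groups have been set up.
```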
- with misc.timer("Sync model states for regular model"): - distributed.sync_model_states(model, device_mesh.get_group(mesh_dim="shard")) - torch.cuda.empty_cache() - distributed.sync_model_states(model, device_mesh.get_group(mesh_dim="replicate")) - torch.cuda.empty_cache() - - with misc.timer("Creating FSDP model"): - self.model = FSDP( - model.to(**self.tensor_kwargs), - sync_module_states=True, # it can reduce network traffic by only loading model in rank0 and sync - sharding_strategy=strategy, - auto_wrap_policy=get_wrap_policy(model), - process_group=fsdp_process_group, - limit_all_gathers=True, - ) - - if self.config.fsdp.checkpoint: - fsdp_blocks_cls = model.net.fsdp_wrap_block_cls - fsdp_blocks_cls = ( - list(fsdp_blocks_cls) - if isinstance(fsdp_blocks_cls, (list, tuple, set)) - else [fsdp_blocks_cls] - ) - log.critical(f"Applying FSDP checkpointing with FSDP blocks: {fsdp_blocks_cls}") - apply_fsdp_checkpointing(self.model, list_block_cls=fsdp_blocks_cls) - - torch.cuda.empty_cache() - - def on_before_zero_grad( - self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int - ) -> None: - del scheduler, optimizer - - if self.config.ema.enabled: - # calculate beta for EMA update - if iteration == 0: - beta = 0.0 - else: - i = iteration + 1 - beta = (1 - 1 / i) ** (self.ema_exp_coefficient + 1) - self.model_ema_worker.update_average(self.model, self.model_ema, beta=beta) - - def training_step( - self, data_batch: Dict[str, torch.Tensor], iteration: int - ) -> Tuple[Dict[str, torch.Tensor] | torch.Tensor]: - # ! Important!!! - # ! make sure the training step is the same as the forward method~(training_step in the super class) - # ! this is necessary to trigger the FSDP-specific pre- and post-forward sharding logic - return self.model(data_batch, iteration) - - def state_dict(self) -> Dict: - raise NotImplementedError( - "FSDPDiffModle does not support state_dict, use state_dict_model and FSDPCheckpointer" - ) - - @misc.timer("FSDP state_dict_model") - def state_dict_model(self) -> Dict: - with FSDP.summon_full_params(self.model): - pass - with FSDP.state_dict_type( - self.model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - ): - model_state = self.model.state_dict() - if self.config.ema.enabled: - with FSDP.summon_full_params(self.model_ema): - pass - with FSDP.state_dict_type( - self.model_ema, - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True), - ): - ema_model_state = self.model_ema.state_dict() - else: - ema_model_state = None - return { - "model": model_state, - "ema": ema_model_state, - } - - def load_state_dict(self, state_dict: Dict, strict: bool = True, assign: bool = False) -> None: - raise NotImplementedError("FSDPDiffModle does not support load_state_dict, using FSDPCheckpointer") - - def init_optimizer_scheduler( - self, optimizer_config: LazyDict, scheduler_config: LazyDict - ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: - optimizer, scheduler = super().init_optimizer_scheduler(optimizer_config, scheduler_config) - self.fsdp_checkpointer.load_optim_scheduler_during_init( - self.model, - optimizer, - scheduler, - ) - return optimizer, scheduler - - @contextmanager - def ema_scope(self, context=None, is_cpu=False): - if self.config.ema.enabled: - self.model_ema_worker.cache(self.model.parameters(), is_cpu=is_cpu) - self.model_ema_worker.copy_to(src_model=self.model_ema, tgt_model=self.model) - if context is not 
None: - log.info(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.config.ema.enabled: - self.model_ema_worker.restore(self.model.parameters()) - if context is not None: - log.info(f"{context}: Restored training weights") - - def get_ckpt_postfix(self) -> Tuple[str, int]: - """Get the checkpoint file postfix. check FSDPCheckpointer for more details - - Args: - iteration (int): The current iteration number. - - Returns: - postfix (str): The postfix of the checkpoint file. - replicate_idx, shard_idx (int), current gpu replicate_idx, shard_idx in FSDP \ - we will not save each ema model in each GPU, \ - ema model with same rate will be saved once - total_ema_num (int) - """ - mesh_shape = parallel_state.fsdp_device_mesh.shape - total_ema_num = min(self.config.ema.num, mesh_shape[0]) - replicate_idx, shard_idx = parallel_state.fsdp_device_mesh.get_coordinate() - if replicate_idx == 0: - return "", 0, shard_idx, total_ema_num - if self.config.ema.enabled: - if replicate_idx < self.config.ema.num: - return f"_RANK{replicate_idx}", replicate_idx, shard_idx, total_ema_num - return "", replicate_idx, shard_idx, total_ema_num - - return FSDPClass - - -@diffusion_fsdp_class_decorator -class FSDPDiffusionModel(DiffusionModel): - pass diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/model_multiview.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/model_multiview.py deleted file mode 100644 index ec3fa3b8ba49cbabbc0052168ca1424a859d7d9e..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/models/model_multiview.py +++ /dev/null @@ -1,241 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
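A small, self-contained restatement of the multi-rate EMA schedule used by `set_up_model` and `on_before_zero_grad` above: the exponent gamma is the largest real root of the cubic parameterised by the EMA rate s, and the per-iteration decay is `(1 - 1/i) ** (gamma + 1)`. The rate `0.1` mirrors the default `s = 0.1` in the code; the formulas are copied from the implementation above.

```python
import numpy as np

def power_ema_exponent(ema_rate: float) -> float:
    # Largest real root of the cubic used for self.ema_exp_coefficient above.
    s = ema_rate
    return np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max()

def ema_beta(iteration: int, gamma: float) -> float:
    # beta = (1 - 1/i) ** (gamma + 1); the first iteration copies the weights outright (beta = 0).
    if iteration == 0:
        return 0.0
    i = iteration + 1
    return (1.0 - 1.0 / i) ** (gamma + 1.0)

gamma = power_ema_exponent(0.1)
print([round(ema_beta(it, gamma), 4) for it in (0, 1, 10, 1000)])
```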
- -import copy -from typing import Callable, Dict, Optional, Tuple, Union - -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import Tensor - -from cosmos_transfer1.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS -from cosmos_transfer1.diffusion.module.parallel import cat_outputs_cp, split_inputs_cp -from cosmos_transfer1.diffusion.training.models.model import DiffusionModel, broadcast_condition -from cosmos_transfer1.diffusion.training.models.model_image import CosmosCondition, diffusion_fsdp_class_decorator -from cosmos_transfer1.utils import log, misc - - -def deepcopy_no_copy_model(obj): - """ - We need to create a copy of the condition construct such that condition masks can be adjusted dynamically, but - the controlnet encoder plug-in also uses the condition construct to pass along the base_model object which cannot be - deep-copied, hence this funciton - """ - if hasattr(obj, "base_model") and obj.base_model is not None: - my_base_model = obj.base_model - obj.base_model = None - copied_obj = copy.deepcopy(obj) - copied_obj.base_model = my_base_model - obj.base_model = my_base_model - else: - copied_obj = copy.deepcopy(obj) - return copied_obj - - -class MultiviewDiffusionModel(DiffusionModel): - def __init__(self, config): - super().__init__(config) - self.n_views = config.n_views - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - state = rearrange(state, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - encoded_state = self.vae.encode(state) - encoded_state = rearrange(encoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) * self.sigma_data - return encoded_state - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - latent = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - decoded_state = self.vae.decode(latent / self.sigma_data) - decoded_state = rearrange(decoded_state, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - return decoded_state - - def compute_loss_with_epsilon_and_sigma( - self, - data_batch: dict[str, torch.Tensor], - x0_from_data_batch: torch.Tensor, - x0: torch.Tensor, - condition: CosmosCondition, - epsilon: torch.Tensor, - sigma: torch.Tensor, - ): - if self.is_image_batch(data_batch): - # Turn off CP - self.net.disable_context_parallel() - else: - if parallel_state.is_initialized(): - if parallel_state.get_context_parallel_world_size() > 1: - # Turn on CP - cp_group = parallel_state.get_context_parallel_group() - self.net.enable_context_parallel(cp_group) - log.debug("[CP] Split x0 and epsilon") - - x0 = rearrange(x0, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - epsilon = rearrange(epsilon, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - - x0 = split_inputs_cp(x=x0, seq_dim=2, cp_group=self.net.cp_group) - epsilon = split_inputs_cp(x=epsilon, seq_dim=2, cp_group=self.net.cp_group) - - x0 = rearrange(x0, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - epsilon = rearrange(epsilon, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - - output_batch, kendall_loss, pred_mse, edm_loss = super( - DiffusionModel, self - ).compute_loss_with_epsilon_and_sigma(data_batch, x0_from_data_batch, x0, condition, epsilon, sigma) - if not self.is_image_batch(data_batch): - if self.loss_reduce == "sum" and parallel_state.get_context_parallel_world_size() > 1: - kendall_loss *= parallel_state.get_context_parallel_world_size() - - return output_batch, kendall_loss, pred_mse, edm_loss - - def 
generate_samples_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - seed: int = 1, - state_shape: Tuple | None = None, - n_sample: int | None = None, - is_negative_prompt: bool = False, - num_steps: int = 35, - solver_option: COMMON_SOLVER_OPTIONS = "2ab", - x_sigma_max: Optional[torch.Tensor] = None, - sigma_max: float | None = None, - guidance_other: Union[float, None] = None, - ) -> Tensor: - """ - Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. - Args: - data_batch (dict): raw data batch draw from the training data loader. - iteration (int): Current iteration number. - guidance (float): guidance weights - seed (int): random seed - state_shape (tuple): shape of the state, default to self.state_shape if not provided - n_sample (int): number of samples to generate - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - num_steps (int): number of steps for the diffusion process - solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) - """ - self._normalize_video_databatch_inplace(data_batch) - self._augment_image_dim_inplace(data_batch) - is_image_batch = self.is_image_batch(data_batch) - if n_sample is None: - input_key = self.input_image_key if is_image_batch else self.input_data_key - n_sample = data_batch[input_key].shape[0] - if state_shape is None: - if is_image_batch: - state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W - x0_fn = self.get_x0_fn_from_batch( - data_batch, guidance, is_negative_prompt=is_negative_prompt, guidance_other=guidance_other - ) - x_sigma_max = ( - misc.arch_invariant_rand( - (n_sample,) + tuple(state_shape), - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - * self.sde.sigma_max - ) - if self.net.is_context_parallel_enabled: - x_sigma_max = rearrange(x_sigma_max, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - - x_sigma_max = split_inputs_cp(x=x_sigma_max, seq_dim=2, cp_group=self.net.cp_group) - - x_sigma_max = rearrange(x_sigma_max, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - - samples = self.sampler( - x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max, solver_option=solver_option - ) - if self.net.is_context_parallel_enabled: - samples = rearrange(samples, "B C (V T) H W -> (B V) C T H W", V=self.n_views) - samples = cat_outputs_cp(samples, seq_dim=2, cp_group=self.net.cp_group) - samples = rearrange(samples, "(B V) C T H W -> B C (V T) H W", V=self.n_views) - - return samples - - def get_x0_fn_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - is_negative_prompt: bool = False, - guidance_other: Union[float, None] = None, - ) -> Callable: - """ - Generates a callable function `x0_fn` based on the provided data batch and guidance factor. - - This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. - - Args: - - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` - - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. 
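A sketch of the view/time bookkeeping used throughout `MultiviewDiffusionModel` above: views are folded into the batch axis so that each view becomes an independent T-frame clip, context-parallel splitting (when enabled) shards the temporal axis of that per-view layout, and the inverse rearrange restores the `(V T)` packing. Shapes are toy values.

```python
import torch
from einops import rearrange

B, C, V, T, H, W = 1, 16, 3, 8, 22, 40  # toy sizes; V is n_views
latent = torch.randn(B, C, V * T, H, W)

# Fold views into the batch so each view is an independent T-frame clip.
per_view = rearrange(latent, "B C (V T) H W -> (B V) C T H W", V=V)
# Context-parallel splitting (split_inputs_cp, seq_dim=2) would shard the T axis of per_view here.
restored = rearrange(per_view, "(B V) C T H W -> B C (V T) H W", V=V)
assert torch.equal(latent, restored)
```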
- - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - - Returns: - - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin - - The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. - """ - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - - to_cp = self.net.is_context_parallel_enabled - # For inference, check if parallel_state is initialized - if parallel_state.is_initialized(): - condition = broadcast_condition(condition, to_tp=True, to_cp=to_cp) - uncondition = broadcast_condition(uncondition, to_tp=True, to_cp=to_cp) - else: - assert not to_cp, "parallel_state is not initialized, context parallel should be turned off." - - if guidance_other is not None: - # assume this is for inference time trajectory guidance for now - assert not parallel_state.is_initialized(), "Parallel state not supported with two guidances." - condition_other = copy.deepcopy(uncondition) - condition_other.trajectory = condition.trajectory - - def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - cond_x0 = self.denoise(noise_x, sigma, condition).x0 - uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 - cond_other_x0 = self.denoise(noise_x, sigma, condition_other).x0 - - raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) + guidance_other * (cond_other_x0 - uncond_x0) - - if "guided_image" in data_batch: - assert False, "not supported" - return raw_x0 - - else: - - def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - cond_x0 = self.denoise(noise_x, sigma, condition).x0 - uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0 - raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0) - if "guided_image" in data_batch: - # replacement trick that enables inpainting with base model - assert "guided_mask" in data_batch, "guided_mask should be in data_batch if guided_image is present" - guide_image = data_batch["guided_image"] - guide_mask = data_batch["guided_mask"] - raw_x0 = guide_mask * guide_image + (1 - guide_mask) * raw_x0 - return raw_x0 - - return x0_fn - - -@diffusion_fsdp_class_decorator -class FSDPDiffusionModel(MultiviewDiffusionModel): - pass diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/modules/attention.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/modules/attention.py deleted file mode 100644 index a4febfe75d2bd3b0e5e87d97db516187af5b88d9..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/modules/attention.py +++ /dev/null @@ -1,486 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from contextlib import nullcontext -from typing import List, Optional, Union - -try: - from megatron.core import parallel_state - - USE_MEGATRON = True -except ImportError: - USE_MEGATRON = False - -import numpy as np -import torch -import torch.nn.functional as F -import transformer_engine as te -from einops import rearrange -from packaging import version -from torch import nn -from torch.nn.attention import SDPBackend -from torch.utils.checkpoint import checkpoint -from transformer_engine.pytorch.attention.dot_product_attention.dot_product_attention import DotProductAttention -from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb - -from cosmos_transfer1.utils import log - -# ---------------------- Feed Forward Network ----------------------- - - -class FeedForward(nn.Module): - """ - Transformer FFN with optional gating - - Parameters: - d_model (int): Dimensionality of input features. - d_ff (int): Dimensionality of the hidden layer. - dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1. - activation (callable, optional): The activation function applied after the first linear layer. - Defaults to nn.ReLU(). - is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer. - Defaults to False. - bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True. - - Example: - >>> ff = FeedForward(d_model=512, d_ff=2048) - >>> x = torch.randn(64, 10, 512) # Example input tensor - >>> output = ff(x) - >>> print(output.shape) # Expected shape: (64, 10, 512) - """ - - def __init__( - self, - d_model: int, - d_ff: int, - dropout: float = 0.1, - activation=nn.ReLU(), - is_gated: bool = False, - bias: bool = False, - ) -> None: - super().__init__() - - tp_group = parallel_state.get_tensor_model_parallel_group(check_initialized=False) - sequence_parallel = getattr(parallel_state, "sequence_parallel", False) - if tp_group is None: - tp_size = 1 # TP is not initialized. 
- else: - tp_size = parallel_state.get_tensor_model_parallel_world_size() - - if tp_size == 1: - self.layer1 = nn.Linear(d_model, d_ff, bias=bias) - self.layer2 = nn.Linear(d_ff, d_model, bias=bias) - else: - assert is_gated is False, "is_gated needs to be False to support Tensor Parallelism" - assert dropout == 0.0, "dropout needs to be 0.0 to support Tensor Parallelism" - self.layer1 = te.pytorch.Linear( - d_model, - d_ff, - bias=bias, - tp_size=tp_size, - tp_group=tp_group, - parallel_mode="column", - sequence_parallel=sequence_parallel, - ) - self.layer2 = te.pytorch.Linear( - d_ff, - d_model, - bias=bias, - tp_size=tp_size, - tp_group=tp_group, - parallel_mode="row", - sequence_parallel=sequence_parallel, - ) - - self.dropout = nn.Dropout(dropout) - self.activation = activation - self.is_gated = is_gated - if is_gated: - self.linear_gate = nn.Linear(d_model, d_ff, bias=False) - - def forward(self, x: torch.Tensor): - g = self.activation(self.layer1(x)) - if self.is_gated: - x = g * self.linear_gate(x) - else: - x = g - - assert self.dropout.p == 0.0, "skipping dropout to save memory" - return self.layer2(x) - - -class GPT2FeedForward(FeedForward): - def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False): - super().__init__( - d_model=d_model, - d_ff=d_ff, - dropout=dropout, - activation=nn.GELU(), - is_gated=False, - bias=bias, - ) - - def forward(self, x: torch.Tensor): - assert self.dropout.p == 0.0, "we skip dropout" - - x = self.layer1(x) - - def activation_layer2_forward(x): - x = self.activation(x) - x = self.layer2(x) - return x - - x = checkpoint(activation_layer2_forward, x, use_reentrant=False) - return x - - -# ---------------------- Normalization Layer ----------------------- - - -def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor: - """ - Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted. - - Args: - x (torch.Tensor): The input tensor to normalize. - dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first. - eps (float, optional): A small constant to ensure numerical stability during division. - - Returns: - torch.Tensor: The normalized tensor. - """ - if dim is None: - dim = list(range(1, x.ndim)) - norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) - norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel())) - return x / norm.to(x.dtype) - - -def get_normalization(name: str, channels: int): - if name == "I": - return nn.Identity() - elif name == "R": - return te.pytorch.RMSNorm(channels, eps=1e-6) - else: - raise ValueError(f"Normalization {name} not found") - - -# ---------------------- Attention Op ----------------------- -# A list of attention ops -if version.parse(torch.__version__) >= version.parse("2.3.0"): - from torch.nn.attention import SDPBackend, sdpa_kernel - - sdpa_context = sdpa_kernel - USE_SDPA = True -elif version.parse(torch.__version__) >= version.parse("2.0.0"): - from torch.backends.cuda import SDPBackend, sdp_kernel - - sdpa_context = sdp_kernel - USE_SDPA = False -else: - sdpa_context = nullcontext - USE_SDPA = False - print( - f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, " - f"you are using PyTorch {torch.__version__}. You might want to consider upgrading." 
- ) - - -class BaseAttentionOp(nn.Module): - def __init__(self): - super().__init__() - - -class FusedAttentionOp(BaseAttentionOp): - def __init__(self): - super().__init__() - - -class TorchAttentionOp(FusedAttentionOp): - def __init__(self, backend: Optional[Union[List[SDPBackend], SDPBackend]] = SDPBackend.EFFICIENT_ATTENTION): - super().__init__() - self.backend = backend - self.sdpa_context = sdpa_context if self.backend is not None else nullcontext - if self.backend is not None: - log.warning( - "SDPA context manager is not working well with torch.compile, causing graph breaks and " - "significant slowdowns. If you are using torch.compile you'll most likely want to turn off " - "this context manager." - ) - - def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """ - Computes the scaled dot-product attention over the input tensors using the specified backend. - B is the batch size, M the sequence length, H the number of heads, and K the embeding size per head - - check F.scaled_dot_product_attention - Args: - q (Tensor): The query tensor of shape [B, Mq, H, K] / [B, ..., H, K] - k (Tensor): The key tensor of shape [B, Mk, H, V] / [B, ..., H, K] - v (Tensor): The value tensor of shape [B, Mk, H, V] / [B, ..., H, V] - - mask (Optional[Tensor]): An optional mask tensor. Follow scaled_dot_product_attention API, mask should be a boolean tensor with shape [B, H, Mq, Mk] - - Returns: - Tensor: [B, Mq, H, V] / [B, ..., H, V] - """ - in_q_shape = q.shape - in_k_shape = k.shape - q = rearrange(q, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) - k = rearrange(k, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) - v = rearrange(v, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) - if mask is not None: - assert mask.dtype == torch.bool, "Mask should be a boolean tensor" - with self.sdpa_context(self.backend): - out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default - return rearrange(out, "b h ... l -> b ... h l").view(*in_q_shape[:-1], in_k_shape[-1]) - - -class Attention(nn.Module): - """ - Generalized attention impl. **With TP support**. - - Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided. - If `context_dim` is None, self-attention is assumed. - - Parameters: - query_dim (int): Dimension of each query vector. - context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed. - heads (int, optional): Number of attention heads. Defaults to 8. - dim_head (int, optional): Dimension of each head. Defaults to 64. - dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0. - attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default. - qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False. - out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False. - qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections. - Defaults to "SSI". - norm_args (dict, optional): Arguments to pass to the normalization function. Defaults to an empty dict. 
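A minimal sketch of the axis shuffling performed by `TorchAttentionOp`: inputs arrive as `[B, M, H, K]`, `F.scaled_dot_product_attention` wants heads before the sequence axis, and the result is transposed back. Shapes are toy values; the default softmax scale is `dim_head ** -0.5`, as noted in the code above.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

B, M, H, K = 2, 128, 8, 64  # toy sizes: batch, sequence, heads, head dim
q, k, v = (torch.randn(B, M, H, K) for _ in range(3))

# SDPA expects heads before the sequence axis: (B, H, M, K).
out = F.scaled_dot_product_attention(
    rearrange(q, "b m h k -> b h m k"),
    rearrange(k, "b m h k -> b h m k"),
    rearrange(v, "b m h k -> b h m k"),
)  # default scale is dim_head ** -0.5
out = rearrange(out, "b h m k -> b m h k")
assert out.shape == (B, M, H, K)
```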
- - Examples: - >>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1) - >>> query = torch.randn(10, 128) # Batch size of 10 - >>> context = torch.randn(10, 256) # Batch size of 10 - >>> output = attn(query, context) # Perform the attention operation - - Note: - https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - """ - - def __init__( - self, - query_dim: int, - context_dim=None, - heads=8, - dim_head=64, - dropout=0.0, - attn_op: Optional[BaseAttentionOp] = None, - qkv_bias: bool = False, - out_bias: bool = False, - qkv_norm: str = "SSI", - norm_args: dict = {}, - backend: str = "transformer_engine", - qkv_format: str = "bshd", - ) -> None: - super().__init__() - log.debug( - f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " - f"{heads} heads with a dimension of {dim_head}. Norm options are {qkv_norm} and norm args are {norm_args}." - ) - self.is_selfattn = context_dim is None # self attention - - inner_dim = dim_head * heads - context_dim = query_dim if context_dim is None else context_dim - - self.heads = heads - self.dim_head = dim_head - self.qkv_format = qkv_format - norm_dim = dim_head - self.backend = backend - self.query_dim = query_dim - self.context_dim = context_dim - self.inner_dim = inner_dim - tp_group = parallel_state.get_tensor_model_parallel_group(check_initialized=False) if USE_MEGATRON else None - if tp_group is None: - self.tp_size = 1 # TP is not initialized. - else: - self.tp_size = parallel_state.get_tensor_model_parallel_world_size() - - if self.backend == "torch": - assert ( - self.tp_size == 1 - ), f"Attention backend {self.backend} cannot use TP size > 1. Attempted: {self.tp_size}" - - assert self.heads % self.tp_size == 0, "the number of heads should be divisible by TP size" - - if self.tp_size == 1: - self.to_q = nn.Sequential( - nn.Linear(query_dim, inner_dim, bias=qkv_bias), - get_normalization(qkv_norm[0], norm_dim, **norm_args), - ) - self.to_k = nn.Sequential( - nn.Linear(context_dim, inner_dim, bias=qkv_bias), - get_normalization(qkv_norm[1], norm_dim, **norm_args), - ) - self.to_v = nn.Sequential( - nn.Linear(context_dim, inner_dim, bias=qkv_bias), - get_normalization(qkv_norm[2], norm_dim, **norm_args), - ) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim, bias=out_bias), - nn.Dropout(dropout), - ) - else: # TP enabled. 
- sequence_parallel = getattr(parallel_state, "sequence_parallel", False) - if sequence_parallel: - assert qkv_format == "sbhd", "sequence parallel only supports sbhd format" - - self.to_q = nn.Sequential( - te.pytorch.Linear( - query_dim, - inner_dim, - bias=qkv_bias, - tp_size=self.tp_size, - tp_group=tp_group, - sequence_parallel=sequence_parallel, - parallel_mode="column", - ), - get_normalization(qkv_norm[0], norm_dim, **norm_args), - ) - self.to_k = nn.Sequential( - te.pytorch.Linear( - context_dim, - inner_dim, - bias=qkv_bias, - tp_size=self.tp_size, - tp_group=tp_group, - sequence_parallel=sequence_parallel, - parallel_mode="column", - ), - get_normalization(qkv_norm[1], norm_dim, **norm_args), - ) - self.to_v = nn.Sequential( - te.pytorch.Linear( - context_dim, - inner_dim, - bias=qkv_bias, - tp_size=self.tp_size, - tp_group=tp_group, - sequence_parallel=sequence_parallel, - parallel_mode="column", - ), - get_normalization(qkv_norm[2], norm_dim, **norm_args), - ) - - self.to_out = nn.Sequential( - te.pytorch.Linear( - inner_dim, - query_dim, - bias=out_bias, - tp_size=self.tp_size, - tp_group=tp_group, - parallel_mode="row", - sequence_parallel=sequence_parallel, - ), - nn.Dropout(dropout), - ) - - if attn_op: # use what is given - self.attn_op = attn_op - elif self.backend == "transformer_engine": - sequence_parallel = getattr(parallel_state, "sequence_parallel", False) if USE_MEGATRON else False - self.attn_op: BaseAttentionOp = DotProductAttention( - self.heads, - self.dim_head, - num_gqa_groups=self.heads, - attention_dropout=0, - qkv_format=qkv_format, - attn_mask_type="no_mask", - tp_size=self.tp_size, - tp_group=tp_group, - sequence_parallel=sequence_parallel, - ) - elif self.backend == "torch": - self.attn_op = TorchAttentionOp(None) - else: - raise ValueError(f"Backend {backend} not found") - - def cal_qkv( - self, x, context=None, mask=None, rope_emb=None, **kwargs - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - del kwargs - - """ - self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers. - Before 07/24/2024, these modules normalize across all heads. - After 07/24/2024, to support tensor parallelism and follow the common practice in the community, - we support to normalize per head. - To keep the checkpoint copatibility with the previous code, - we keep the nn.Sequential but call the projection and the normalization layers separately. - """ - - q = self.to_q[0](x) - context = x if context is None else context - k = self.to_k[0](context) - v = self.to_v[0](context) - q, k, v = map( - lambda t: rearrange(t, "b ... (n c) -> b ... n c", n=self.heads // self.tp_size, c=self.dim_head), - (q, k, v), - ) - - def apply_norm_and_rotary_pos_emb(q, k, v, rope_emb): - q = self.to_q[1](q) - k = self.to_k[1](k) - v = self.to_v[1](v) - if self.is_selfattn and rope_emb is not None: # only apply to self-attention! - q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True) - k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True) - return q, k, v - - q, k, v = checkpoint(apply_norm_and_rotary_pos_emb, q, k, v, rope_emb, use_reentrant=False) - - return q, k, v - - def cal_attn(self, q, k, v, mask=None): - if self.backend == "transformer_engine": - seq_dim = self.qkv_format.index("s") - assert ( - q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 - ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." 
- out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None) # [B, Mq, H, V] - return self.to_out(out) - elif self.backend == "torch": - out = self.attn_op(q, k, v, mask=mask) # [B, Mq, H, V] - return self.to_out(rearrange(out, " b ... n c -> b ... (n c)")) - else: - raise ValueError(f"Backend {self.backend} not found") - - def forward( - self, - x, - context=None, - mask=None, - rope_emb=None, - **kwargs, - ): - """ - Args: - x (Tensor): The query tensor of shape [B, Mq, K] - context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None - """ - q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs) - return self.cal_attn(q, k, v, mask) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/modules/blocks.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/modules/blocks.py deleted file mode 100644 index e2325cc887baa130d56444aa9fd8432852772ebc..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/modules/blocks.py +++ /dev/null @@ -1,1169 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import Optional - -import torch -from einops import rearrange, repeat -from einops.layers.torch import Rearrange -from megatron.core import parallel_state -from torch import nn -from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb - -from cosmos_transfer1.diffusion.training.modules.attention import Attention, GPT2FeedForward -from cosmos_transfer1.diffusion.training.tensor_parallel import gather_along_first_dim -from cosmos_transfer1.utils import log - - -class SDXLTimesteps(nn.Module): - def __init__(self, num_channels: int = 320): - super().__init__() - self.num_channels = num_channels - - def forward(self, timesteps): - in_dype = timesteps.dtype - half_dim = self.num_channels // 2 - exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) - exponent = exponent / (half_dim - 0.0) - - emb = torch.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] - - sin_emb = torch.sin(emb) - cos_emb = torch.cos(emb) - emb = torch.cat([cos_emb, sin_emb], dim=-1) - - return emb.to(in_dype) - - -class SDXLTimestepEmbedding(nn.Module): - def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False): - super().__init__() - log.critical( - f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." 
- ) - self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora) - self.activation = nn.SiLU() - self.use_adaln_lora = use_adaln_lora - if use_adaln_lora: - self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False) - else: - self.linear_2 = nn.Linear(out_features, out_features, bias=True) - - def forward(self, sample: torch.Tensor) -> torch.Tensor: - emb = self.linear_1(sample) - emb = self.activation(emb) - emb = self.linear_2(emb) - - if self.use_adaln_lora: - adaln_lora_B_3D = emb - emb_B_D = sample - else: - emb_B_D = emb - adaln_lora_B_3D = None - - return emb_B_D, adaln_lora_B_3D - - -def modulate(x, shift, scale): - return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - - -class PatchEmbed(nn.Module): - """ - PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers, - depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions, - making it suitable for video and image processing tasks. It supports dividing the input into patches and embedding each - patch into a vector of size `out_channels`. - - Parameters: - - spatial_patch_size (int): The size of each spatial patch. - - temporal_patch_size (int): The size of each temporal patch. - - in_channels (int): Number of input channels. Default: 3. - - out_channels (int): The dimension of the embedding vector for each patch. Default: 768. - - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True. - - keep_spatio (bool): If True, the spatial dimensions are kept separate in the output tensor, otherwise, they are flattened. Default: False. - - legacy_patch_emb (bool): If True, applies 3D convolutional layers for video inputs, otherwise, use Linear! The legacy model is for backward compatibility. Default: True. - The output shape of the module depends on the `keep_spatio` flag. If `keep_spatio`=True, the output retains the spatial dimensions. - Otherwise, the spatial dimensions are flattened into a single dimension. - """ - - def __init__( - self, - spatial_patch_size, - temporal_patch_size, - in_channels=3, - out_channels=768, - bias=True, - keep_spatio=False, - legacy_patch_emb: bool = True, - ): - super().__init__() - self.spatial_patch_size = spatial_patch_size - self.temporal_patch_size = temporal_patch_size - assert keep_spatio, "Only support keep_spatio=True" - self.keep_spatio = keep_spatio - self.legacy_patch_emb = legacy_patch_emb - - if legacy_patch_emb: - self.proj = nn.Conv3d( - in_channels, - out_channels, - kernel_size=(temporal_patch_size, spatial_patch_size, spatial_patch_size), - stride=(temporal_patch_size, spatial_patch_size, spatial_patch_size), - bias=bias, - ) - self.out = Rearrange("b c t h w -> b t h w c") - else: - self.proj = nn.Sequential( - Rearrange( - "b c (t r) (h m) (w n) -> b t h w (c r m n)", - r=temporal_patch_size, - m=spatial_patch_size, - n=spatial_patch_size, - ), - nn.Linear( - in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias - ), - ) - self.out = nn.Identity() - - def forward(self, x): - """ - Forward pass of the PatchEmbed module. - - Parameters: - - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where - B is the batch size, - C is the number of channels, - T is the temporal dimension, - H is the height, and - W is the width of the input. - - Returns: - - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. 
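A sketch of the non-legacy `PatchEmbed` path above: each `temporal_patch_size x spatial_patch_size x spatial_patch_size` patch is folded into the channel axis by the einops rearrange, and a single `nn.Linear` then maps it to the embedding width. Patch sizes here are toy values.

```python
import torch
from einops import rearrange

B, C, T, H, W = 1, 3, 8, 64, 64
pt, ps = 1, 2  # toy temporal / spatial patch sizes
x = torch.randn(B, C, T, H, W)

# Fold each (pt x ps x ps) patch into the channel axis; a Linear(C*pt*ps*ps -> out_channels)
# applied to the last axis would complete the non-legacy PatchEmbed projection.
patches = rearrange(x, "b c (t r) (h m) (w n) -> b t h w (c r m n)", r=pt, m=ps, n=ps)
assert patches.shape == (B, T // pt, H // ps, W // ps, C * pt * ps * ps)
```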
- """ - assert x.dim() == 5 - _, _, T, H, W = x.shape - assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 - assert T % self.temporal_patch_size == 0 - x = self.proj(x) - return self.out(x) - - -class ExtraTokenPatchEmbed(PatchEmbed): - def __init__(self, *args, out_channels: int = 768, keep_spatio: bool = False, **kwargs): - assert keep_spatio, "ExtraTokenPatchEmbed only supports keep_spatio=True" - super().__init__(*args, out_channels=out_channels, keep_spatio=keep_spatio, **kwargs) - self.temporal_token = nn.Parameter(torch.randn(1, 1, 1, 1, out_channels)) - self.spatial_token = nn.Parameter(torch.randn(1, 1, 1, 1, out_channels)) - - def forward(self, x): - x_B_T_H_W_C = super().forward(x) - B, T, H, W, C = x_B_T_H_W_C.shape - x_B_T_H_W_C = torch.cat( - [ - x_B_T_H_W_C, - self.temporal_token.repeat(B, 1, H, W, 1), - ], - dim=1, - ) - x_B_T_H_W_C = torch.cat( - [ - x_B_T_H_W_C, - self.spatial_token.repeat(B, T, H, 1, 1), - ], - dim=3, - ) - return x_B_T_H_W_C - - -class ExpertChoiceMoEGate(nn.Module): - """ - ExpertChoiceMoEGate determines which tokens go - to which experts (and how much to weigh each expert). - - Args: - hidden_size (int): Dimensionality of input features. - num_experts (int): Number of experts (E). - capacity (int): Capacity (number of tokens) each expert can process (C). - """ - - def __init__( - self, - hidden_size: int, - num_experts: int, - capacity: int, - ): - super().__init__() - self.hidden_size = hidden_size - self.num_experts = num_experts - self.capacity = capacity - - self.router = nn.Parameter(torch.empty((self.num_experts, self.hidden_size))) - self.reset_parameters() - - def reset_parameters(self): - torch.nn.init.kaiming_uniform_(self.router) - - def forward(self, x: torch.Tensor): - """ - Args: - x (Tensor): Input of shape (B, S, D) - Returns: - gating (Tensor): Gating weights of shape (B, E, C), - where E = num_experts, C = capacity (top-k). - dispatch (Tensor): Dispatch mask of shape (B, E, C, S). - index (Tensor): Indices of top-k tokens for each expert, - shape (B, E, C). - """ - B, S, D = x.shape - E, C = self.num_experts, self.capacity - - # token-expert affinity scores - logits = torch.einsum("bsd,de->bse", x, self.router) - affinity = torch.nn.functional.softmax(logits, dim=-1) # (B, S, E) - - # gather topk tokens for each expert - affinity_t = affinity.transpose(1, 2) # (B, E, S) - - # select top-k tokens for each expert - gating, index = torch.topk(affinity_t, k=C, dim=-1) # (B, E, C) - - # one-hot dispatch mask - dispatch = torch.nn.functional.one_hot(index, num_classes=S).float() # (B, E, C, S) - - return gating, dispatch, index - - -class ExpertChoiceMoELayer(nn.Module): - """ - ExpertChoiceMoELayer uses the ExpertChoiceMoEGate to route tokens - to experts, process them, and then combine the outputs. - - Args: - gate_hidden_size (int): Dimensionality of input features. - ffn_hidden_size (int): Dimension of hidden layer in each expert feedforward (e.g., GPT2FeedForward). - num_experts (int): Number of experts (E). - capacity (int): Capacity (number of tokens) each expert can process (C). - expert_cls (nn.Module): The class to instantiate each expert. Defaults to GPT2FeedForward. - expert_kwargs (dict): Extra kwargs to pass to each expert class. 
- """ - - def __init__( - self, - gate_hidden_size: int, - ffn_hidden_size: int, - num_experts: int, - capacity: int, - expert_class: nn.Module = GPT2FeedForward, - expert_kwargs=None, - ): - super().__init__() - if not expert_kwargs: - expert_kwargs = {} - - self.gate_hidden_size = gate_hidden_size - self.ffn_hidden_size = ffn_hidden_size - self.num_experts = num_experts - self.capacity = capacity - - self.gate = ExpertChoiceMoEGate(gate_hidden_size, num_experts, capacity) - - self.experts = nn.ModuleList( - [expert_class(gate_hidden_size, ffn_hidden_size, **expert_kwargs) for _ in range(num_experts)] - ) - - def forward(self, x: torch.Tensor): - """ - Args: - x (Tensor): Input of shape (B, S, D). - - Returns: - x_out (Tensor): Output of shape (B, S, D), after dispatching tokens - to experts and combining their outputs. - """ - B, S, D = x.shape - E, C = self.num_experts, self.capacity - - # gating: (B, E, C) - # dispatch: (B, E, C, S) - gating, dispatch, index = self.gate(x) - - # collect input tokens for each expert - x_in = torch.einsum("becs,bsd->becd", dispatch, x) - - # process through each expert - expert_outputs = [self.experts[e](x_in[:, e]) for e in range(E)] - - x_e = torch.stack(expert_outputs, dim=1) # (B, E, C, D) - - # gating: (B, E, C), dispatch: (B, E, C, S), x_e: (B, E, C, d) - # x_out: (B, S, D) - # each token is placed back to its location with weighting - x_out = torch.einsum("becs,bec,becd->bsd", dispatch, gating, x_e) - - return x_out - - -class FinalLayer(nn.Module): - """ - The final layer of video DiT. - """ - - def __init__( - self, - hidden_size, - spatial_patch_size, - temporal_patch_size, - out_channels, - use_adaln_lora: bool = False, - adaln_lora_dim: int = 256, - ): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear( - hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False - ) - self.hidden_size = hidden_size - self.n_adaln_chunks = 2 - self.use_adaln_lora = use_adaln_lora - if use_adaln_lora: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(hidden_size, adaln_lora_dim, bias=False), - nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False), - ) - else: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False) - ) - - self.sequence_parallel = getattr(parallel_state, "sequence_parallel", False) - - def forward( - self, - x_BT_HW_D, - emb_B_D, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - ): - if self.use_adaln_lora: - assert adaln_lora_B_3D is not None - shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk( - 2, dim=1 - ) - else: - shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1) - - B = emb_B_D.shape[0] - T = x_BT_HW_D.shape[0] // B - shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T) - x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D) - if self.sequence_parallel: - x_T_B_HW_D = rearrange(x_BT_HW_D, "(b t) hw d -> t b hw d", b=B, t=T) - x_T_B_HW_D = gather_along_first_dim(x_T_B_HW_D, parallel_state.get_tensor_model_parallel_group()) - x_BT_HW_D = rearrange(x_T_B_HW_D, "t b hw d -> (b t) hw d", b=B) - - x_BT_HW_D = self.linear(x_BT_HW_D) - return x_BT_HW_D - - def forward_with_memory_save( - self, - x_BT_HW_D_before_gate: torch.Tensor, - x_BT_HW_D_skip: torch.Tensor, - 
gate_L_B_D: torch.Tensor, - emb_B_D, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - ): - if self.use_adaln_lora: - assert adaln_lora_B_3D is not None - shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk( - 2, dim=1 - ) - else: - shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1) - - B = emb_B_D.shape[0] - T = x_BT_HW_D_before_gate.shape[0] // B - shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T) - gate_BT_1_D = repeat(gate_L_B_D, "1 b d -> (b t) 1 d", t=T) - - def _fn(_x_before_gate, _x_skip): - previous_block_out = _x_skip + gate_BT_1_D * _x_before_gate - _x = modulate(self.norm_final(previous_block_out), shift_BT_D, scale_BT_D) - return self.linear(_x) - - return torch.utils.checkpoint.checkpoint(_fn, x_BT_HW_D_before_gate, x_BT_HW_D_skip, use_reentrant=False) - - -class VideoAttn(nn.Module): - """ - Implements video attention with optional cross-attention capabilities. - - This module supports both self-attention within the video frames and cross-attention - with an external context. It's designed to work with flattened spatial dimensions - to accommodate for video input. - - Attributes: - x_dim (int): Dimensionality of the input feature vectors. - context_dim (Optional[int]): Dimensionality of the external context features. - If None, the attention does not utilize external context. - num_heads (int): Number of attention heads. - bias (bool): If true, bias is added to the query, key, value projections. - x_format (str): The shape format of x tenosor. - n_views (int): Extra parameter used in multi-view diffusion model. It indicated total number of view we model together. - """ - - def __init__( - self, - x_dim: int, - context_dim: Optional[int], - num_heads: int, - bias: bool = False, - x_format: str = "BTHWD", - n_views: int = 1, - ) -> None: - super().__init__() - self.n_views = n_views - self.x_format = x_format - if self.x_format == "BTHWD": - qkv_format = "bshd" - elif self.x_format == "THWBD": - qkv_format = "sbhd" - else: - raise NotImplementedError(f"Unsupported x_format: {self.x_format}") - - self.attn = Attention( - x_dim, - context_dim, - num_heads, - x_dim // num_heads, - qkv_bias=bias, - qkv_norm="RRI", - out_bias=bias, - qkv_format=qkv_format, - ) - - def forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - crossattn_mask: Optional[torch.Tensor] = None, - rope_emb_L_1_1_D: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Forward pass for video attention. - - Args: - x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data. - context (Tensor): Context tensor of shape (B, M, D) or (M, B, D), where M is the sequence length of the context. - crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms. - rope_emb_L_1_1_D (Optional[Tensor]): Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. transformer_engine format - - Returns: - Tensor: The output tensor with applied attention, maintaining the input shape. 
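# --- Editor's aside (not part of the original file): a toy illustration of the
# shape handling described in the VideoAttn docstring above. The temporal and
# spatial axes are flattened into a single token axis before attention and
# restored afterwards; all sizes here are arbitrary examples.
import torch
from einops import rearrange

B, T, H, W, D = 1, 4, 8, 8, 32
x_B_T_H_W_D = torch.randn(B, T, H, W, D)

x_B_THW_D = rearrange(x_B_T_H_W_D, "b t h w d -> b (t h w) d")      # (1, 256, 32): tokens for attention
# ... attention over the flattened THW token axis would run here ...
x_back = rearrange(x_B_THW_D, "b (t h w) d -> b t h w d", h=H, w=W)
assert torch.equal(x_back, x_B_T_H_W_D)
# --- end editor's aside ---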
- """ - - if self.x_format == "BTHWD": - if context is not None and self.n_views > 1: - x_B_T_H_W_D = rearrange(x, "b (v t) h w d -> (v b) t h w d", v=self.n_views) - context_B_M_D = rearrange(context, "b (v m) d -> (v b) m d", v=self.n_views) - else: - x_B_T_H_W_D = x - context_B_M_D = context - B, T, H, W, D = x_B_T_H_W_D.shape - x_B_THW_D = rearrange(x_B_T_H_W_D, "b t h w d -> b (t h w) d") - x_B_THW_D = self.attn(x_B_THW_D, context_B_M_D, crossattn_mask, rope_emb=rope_emb_L_1_1_D) - - # reshape it back to video format - x_B_T_H_W_D = rearrange(x_B_THW_D, "b (t h w) d -> b t h w d", h=H, w=W) - if context is not None and self.n_views > 1: - x_B_T_H_W_D = rearrange(x_B_T_H_W_D, "(v b) t h w d -> b (v t) h w d", v=self.n_views) - return x_B_T_H_W_D - elif self.x_format == "THWBD": - if context is not None and self.n_views > 1: - x_T_H_W_B_D = rearrange(x, "(v t) h w b d -> t h w (v b) d", v=self.n_views) - context_M_B_D = rearrange(context, "(v m) b d -> m (v b) d", v=self.n_views) - else: - x_T_H_W_B_D = x - context_M_B_D = context - T, H, W, B, D = x_T_H_W_B_D.shape - x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d") - x_THW_B_D = self.attn( - x_THW_B_D, - context_M_B_D, - crossattn_mask, - rope_emb=rope_emb_L_1_1_D, - ) - x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) - if context is not None and self.n_views > 1: - x_T_H_W_B_D = rearrange(x_T_H_W_B_D, "t h w (v b) d -> (v t) h w b d", v=self.n_views) - return x_T_H_W_B_D - else: - raise NotImplementedError(f"Unsupported x_format: {self.x_format}") - - -def checkpoint_norm_state(norm_state, x, scale, shift): - normalized = norm_state(x) - return normalized * (1 + scale) + shift - - -class DITBuildingBlock(nn.Module): - """ - DIT Building Block for constructing various types of attention or MLP blocks dynamically based on a specified block type. - - This class instantiates different types of buildig block / attn and MLP based on config, and applies crossponding forward pass during training. - - Attributes: - block_type (str): Type of block to be used ('spatial_sa', 'temporal_sa', 'cross_attn', 'full_attn', 'mlp'). - x_dim (int): Dimensionality of the input features. - context_dim (Optional[int]): Dimensionality of the external context, required for cross attention blocks. - num_heads (int): Number of attention heads. - mlp_ratio (float): Multiplier for the dimensionality of the MLP hidden layer compared to input. - spatial_win_size (int): Window size for spatial self-attention. - temporal_win_size (int): Window size for temporal self-attention. - bias (bool): Whether to include bias in attention and MLP computations. - mlp_dropout (float): Dropout rate for MLP blocks. - n_views (int): Extra parameter used in multi-view diffusion model. It indicated total number of view we model together. 
- """ - - def __init__( - self, - block_type: str, - x_dim: int, - context_dim: Optional[int], - num_heads: int, - mlp_ratio: float = 4.0, - window_sizes: list = [], - spatial_win_size: int = 1, - temporal_win_size: int = 1, - bias: bool = False, - mlp_dropout: float = 0.0, - x_format: str = "BTHWD", - use_adaln_lora: bool = False, - adaln_lora_dim: int = 256, - n_views: int = 1, - ) -> None: - block_type = block_type.lower() - - super().__init__() - self.x_format = x_format - if block_type in ["cross_attn", "ca"]: - self.block = VideoAttn( - x_dim, - context_dim, - num_heads, - bias=bias, - x_format=self.x_format, - n_views=n_views, - ) - elif block_type in ["full_attn", "fa"]: - self.block = VideoAttn(x_dim, None, num_heads, bias=bias, x_format=self.x_format) - elif block_type in ["mlp", "ff"]: - self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias) - else: - raise ValueError(f"Unknown block type: {block_type}") - - self.block_type = block_type - self.use_adaln_lora = use_adaln_lora - - self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) - self.n_adaln_chunks = 3 - if use_adaln_lora: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(x_dim, adaln_lora_dim, bias=False), - nn.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False), - ) - else: - self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False)) - - def forward_with_attn_memory_save( - self, - x_before_gate: torch.Tensor, - x_skip: torch.Tensor, - gate_L_B_D: torch.Tensor, - emb_B_D: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - rope_emb_L_1_1_D: Optional[torch.Tensor] = None, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - extra_per_block_pos_emb: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - ): - del crossattn_mask - assert isinstance(self.block, VideoAttn), "only support VideoAttn impl" - if self.use_adaln_lora: - shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( - self.n_adaln_chunks, dim=1 - ) - else: - shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) - - shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( - shift_B_D.unsqueeze(0), - scale_B_D.unsqueeze(0), - gate_B_D.unsqueeze(0), - ) - - def _fn(_x_before_gate, _x_skip, _context): - previous_block_out = _x_skip + gate_L_B_D * _x_before_gate - if extra_per_block_pos_emb is not None: - previous_block_out = previous_block_out + extra_per_block_pos_emb - _normalized_x = self.norm_state(previous_block_out) - normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D - # context = normalized_x if _context is None else _context - context = normalized_x if self.block.attn.is_selfattn else _context - return ( - self.block.attn.to_q[0](normalized_x), - self.block.attn.to_k[0](context), - self.block.attn.to_v[0](context), - previous_block_out, - ) - - q, k, v, previous_block_out = torch.utils.checkpoint.checkpoint( - _fn, x_before_gate, x_skip, crossattn_emb, use_reentrant=False - ) - - def attn_fn(_q, _k, _v): - q, k, v = map( - lambda t: rearrange( - t, - "b ... (n c) -> b ... 
n c", - n=self.block.attn.heads // self.block.attn.tp_size, - c=self.block.attn.dim_head, - ), - (_q, _k, _v), - ) - q = self.block.attn.to_q[1](q) - k = self.block.attn.to_k[1](k) - v = self.block.attn.to_v[1](v) - if self.block.attn.is_selfattn and rope_emb_L_1_1_D is not None: # only apply to self-attention! - q = apply_rotary_pos_emb(q, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) - k = apply_rotary_pos_emb(k, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) - - if self.block.attn.is_selfattn: - return q, k, v - - seq_dim = self.block.attn.qkv_format.index("s") - assert ( - q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 - ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." - if regional_contexts is not None and region_masks is not None: - return self.block.attn.regional_attn_op( - q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None - ) # [B, Mq, H, V] - else: - return self.block.attn.attn_op( - q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None - ) # [B, Mq, H, V] - - assert self.block.attn.backend == "transformer_engine", "Only support transformer_engine backend for now." - - if self.block.attn.is_selfattn: - q, k, v = torch.utils.checkpoint.checkpoint(attn_fn, q, k, v, use_reentrant=False) - seq_dim = self.block.attn.qkv_format.index("s") - assert ( - q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 - ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." - softmax_attn_output = self.block.attn.attn_op( - q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None - ) # [B, Mq, H, V] - else: - softmax_attn_output = torch.utils.checkpoint.checkpoint(attn_fn, q, k, v, use_reentrant=False) - attn_out = self.block.attn.to_out(softmax_attn_output) - return _gate_L_B_D, attn_out, previous_block_out - - def forward_with_x_attn_memory_save( - self, - x_before_gate: torch.Tensor, - x_skip: torch.Tensor, - gate_L_B_D: torch.Tensor, - emb_B_D: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - rope_emb_L_1_1_D: Optional[torch.Tensor] = None, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - extra_per_block_pos_emb: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - ): - del crossattn_mask - assert isinstance(self.block, VideoAttn) - if self.use_adaln_lora: - shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( - self.n_adaln_chunks, dim=1 - ) - else: - shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) - - shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( - shift_B_D.unsqueeze(0), - scale_B_D.unsqueeze(0), - gate_B_D.unsqueeze(0), - ) - - def _fn(_x_before_gate, _x_skip, _context): - previous_block_out = _x_skip + gate_L_B_D * _x_before_gate - if extra_per_block_pos_emb is not None: - previous_block_out = previous_block_out + extra_per_block_pos_emb - _normalized_x = self.norm_state(previous_block_out) - normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D - # context = normalized_x if _context is None else _context - context = normalized_x if self.block.attn.is_selfattn else _context - return ( - self.block.attn.to_q[0](normalized_x), - self.block.attn.to_k[0](context), - self.block.attn.to_v[0](context), - previous_block_out, - ) - - q, k, v, previous_block_out = torch.utils.checkpoint.checkpoint( - _fn, x_before_gate, x_skip, 
crossattn_emb, use_reentrant=False - ) - - def x_attn_fn(_q, _k, _v): - q, k, v = map( - lambda t: rearrange( - t, - "b ... (n c) -> b ... n c", - n=self.block.attn.heads // self.block.attn.tp_size, - c=self.block.attn.dim_head, - ), - (_q, _k, _v), - ) - q = self.block.attn.to_q[1](q) - k = self.block.attn.to_k[1](k) - v = self.block.attn.to_v[1](v) - if self.block.attn.is_selfattn and rope_emb_L_1_1_D is not None: # only apply to self-attention! - q = apply_rotary_pos_emb(q, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) - k = apply_rotary_pos_emb(k, rope_emb_L_1_1_D, tensor_format=self.block.attn.qkv_format, fused=True) - - seq_dim = self.block.attn.qkv_format.index("s") - assert ( - q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1 - ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version." - if regional_contexts is not None and region_masks is not None: - softmax_attn_output = self.block.attn.regional_attn_op( - q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None - ) - return self.block.attn.to_out(softmax_attn_output) - else: - softmax_attn_output = self.block.attn.attn_op( - q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None - ) - return self.block.attn.to_out(softmax_attn_output) - - assert self.block.attn.backend == "transformer_engine", "Only support transformer_engine backend for now." - - attn_out = torch.utils.checkpoint.checkpoint(x_attn_fn, q, k, v, use_reentrant=False) - return _gate_L_B_D, attn_out, previous_block_out - - def forward_with_ffn_memory_save( - self, - x_before_gate: torch.Tensor, - x_skip: torch.Tensor, - gate_L_B_D: torch.Tensor, - emb_B_D: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - rope_emb_L_1_1_D: Optional[torch.Tensor] = None, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - extra_per_block_pos_emb: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - ): - del crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, regional_contexts, region_masks - assert isinstance(self.block, GPT2FeedForward) - if self.use_adaln_lora: - shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( - self.n_adaln_chunks, dim=1 - ) - else: - shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) - - shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( - shift_B_D.unsqueeze(0), - scale_B_D.unsqueeze(0), - gate_B_D.unsqueeze(0), - ) - - def _fn(_x_before_gate, _x_skip): - previous_block_out = _x_skip + gate_L_B_D * _x_before_gate - if extra_per_block_pos_emb is not None: - previous_block_out = previous_block_out + extra_per_block_pos_emb - _normalized_x = self.norm_state(previous_block_out) - normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D - - assert self.block.dropout.p == 0.0, "we skip dropout to save memory" - - return self.block.layer1(normalized_x), previous_block_out - - intermediate_output, previous_block_out = torch.utils.checkpoint.checkpoint( - _fn, x_before_gate, x_skip, use_reentrant=False - ) - - def _fn2(_x): - _x = self.block.activation(_x) - return self.block.layer2(_x) - - return ( - _gate_L_B_D, - torch.utils.checkpoint.checkpoint(_fn2, intermediate_output, use_reentrant=False), - previous_block_out, - ) - - def forward_with_ffn_memory_save_upgrade( - self, - x_before_gate: torch.Tensor, - x_skip: torch.Tensor, - gate_L_B_D: torch.Tensor, - emb_B_D: torch.Tensor, - 
crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - rope_emb_L_1_1_D: Optional[torch.Tensor] = None, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - extra_per_block_pos_emb: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - ): - del crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, regional_contexts, region_masks - assert isinstance(self.block, GPT2FeedForward) - if self.use_adaln_lora: - shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( - self.n_adaln_chunks, dim=1 - ) - else: - shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) - - shift_L_B_D, scale_L_B_D, _gate_L_B_D = ( - shift_B_D.unsqueeze(0), - scale_B_D.unsqueeze(0), - gate_B_D.unsqueeze(0), - ) - - def _fn2(_x): - _x = self.block.activation(_x) - return self.block.layer2(_x) - - def _fn(_x_before_gate, _x_skip): - previous_block_out = _x_skip + gate_L_B_D * _x_before_gate - if extra_per_block_pos_emb is not None: - previous_block_out = previous_block_out + extra_per_block_pos_emb - _normalized_x = self.norm_state(previous_block_out) - normalized_x = _normalized_x * (1 + scale_L_B_D) + shift_L_B_D - - assert self.block.dropout.p == 0.0, "we skip dropout to save memory" - - return _fn2(self.block.layer1(normalized_x)), previous_block_out - - output, previous_block_out = torch.utils.checkpoint.checkpoint(_fn, x_before_gate, x_skip, use_reentrant=False) - - return ( - _gate_L_B_D, - output, - previous_block_out, - ) - - def forward_with_memory_save( - self, - x_before_gate: torch.Tensor, - x_skip: torch.Tensor, - gate_L_B_D: torch.Tensor, - emb_B_D: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - rope_emb_L_1_1_D: Optional[torch.Tensor] = None, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - extra_per_block_pos_emb: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - ): - if isinstance(self.block, VideoAttn): - if self.block.attn.is_selfattn: - fn = self.forward_with_attn_memory_save - else: - fn = self.forward_with_x_attn_memory_save - else: - # fn = self.forward_with_ffn_memory_save - fn = self.forward_with_ffn_memory_save_upgrade - return fn( - x_before_gate, - x_skip, - gate_L_B_D, - emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D, - adaln_lora_B_3D, - extra_per_block_pos_emb, - ) - - def forward( - self, - x: torch.Tensor, - emb_B_D: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - rope_emb_L_1_1_D: Optional[torch.Tensor] = None, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Forward pass for dynamically configured blocks with adaptive normalization. - - Args: - x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D). - emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation. - crossattn_emb (Tensor): Tensor for cross-attention blocks. - crossattn_mask (Optional[Tensor]): Optional mask for cross-attention. - rope_emb_L_1_1_D (Optional[Tensor]): Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. transformer_engine format - adaln_lora_B_3D (Optional[Tensor]): Additional embedding for adaptive layer norm. 
- regional_contexts (Optional[List[Tensor]]): List of regional context tensors. - region_masks (Optional[Tensor]): Region masks of shape (B, R, THW). - - Returns: - Tensor: The output tensor after processing through the configured block and adaptive normalization. - """ - if self.use_adaln_lora: - shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( - self.n_adaln_chunks, dim=1 - ) - else: - shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) - - if self.x_format == "BTHWD": - shift_B_1_1_1_D, scale_B_1_1_1_D, gate_B_1_1_1_D = ( - shift_B_D.unsqueeze(1).unsqueeze(2).unsqueeze(3), - scale_B_D.unsqueeze(1).unsqueeze(2).unsqueeze(3), - gate_B_D.unsqueeze(1).unsqueeze(2).unsqueeze(3), - ) - if self.block_type in ["spatial_sa", "temporal_sa", "window_attn", "ssa", "tsa", "wa"]: - x = x + gate_B_1_1_1_D * self.block( - self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - ) - elif self.block_type in ["full_attn", "fa"]: - x = x + gate_B_1_1_1_D * self.block( - self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, - context=None, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - ) - elif self.block_type in ["cross_attn", "ca"]: - normalized_x = self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D - x = x + gate_B_1_1_1_D * self.block( - normalized_x, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D, - regional_contexts=regional_contexts, - region_masks=region_masks, - ) - elif self.block_type in ["mlp", "ff"]: - x = x + gate_B_1_1_1_D * self.block( - self.norm_state(x) * (1 + scale_B_1_1_1_D) + shift_B_1_1_1_D, - ) - else: - raise ValueError(f"Unknown block type: {self.block_type}") - elif self.x_format == "THWBD": - shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = ( - shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), - scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), - gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), - ) - - if self.block_type in ["mlp", "ff"]: - x = x + gate_1_1_1_B_D * self.block( - torch.utils.checkpoint.checkpoint( - checkpoint_norm_state, self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D, use_reentrant=False - ), - ) - elif self.block_type in ["full_attn", "fa"]: - x = x + gate_1_1_1_B_D * self.block( - torch.utils.checkpoint.checkpoint( - checkpoint_norm_state, self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D, use_reentrant=False - ), - context=None, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - ) - elif self.block_type in ["cross_attn", "ca"]: - x = x + gate_1_1_1_B_D * self.block( - torch.utils.checkpoint.checkpoint( - checkpoint_norm_state, self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D, use_reentrant=False - ), - context=crossattn_emb, - crossattn_mask=crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - ) - else: - raise ValueError(f"Unknown block type: {self.block_type}") - else: - raise NotImplementedError(f"Unsupported x_format: {self.x_format}") - return x - - -class GeneralDITTransformerBlock(nn.Module): - """ - This class is a wrapper for a list of DITBuildingBlock. - It's not essential, refactor it if needed. 
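# --- Editor's aside (not part of the original file): a small sketch of the
# adaptive-LayerNorm modulation pattern used in DITBuildingBlock.forward above.
# The conditioning embedding is projected to shift/scale/gate chunks, here
# through a low-rank ("adaLN-LoRA"-style) bottleneck, and the sub-block output
# is gated into the residual stream. The Linear stand-in for the sub-block and
# all sizes are assumptions for illustration.
import torch
from torch import nn

B, S, D, lora_dim = 2, 16, 64, 8
x = torch.randn(B, S, D)                 # tokens (spatial/temporal dims already flattened)
emb = torch.randn(B, D)                  # conditioning embedding (e.g. timestep)

norm = nn.LayerNorm(D, elementwise_affine=False, eps=1e-6)
adaln = nn.Sequential(                   # low-rank adaLN projection
    nn.SiLU(),
    nn.Linear(D, lora_dim, bias=False),
    nn.Linear(lora_dim, 3 * D, bias=False),
)
block = nn.Linear(D, D)                  # stand-in for the attention / MLP sub-block

shift, scale, gate = adaln(emb).chunk(3, dim=1)                       # each (B, D)
shift, scale, gate = (t.unsqueeze(1) for t in (shift, scale, gate))   # broadcast over tokens
x = x + gate * block(norm(x) * (1 + scale) + shift)
# --- end editor's aside ---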
- """ - - def __init__( - self, - x_dim: int, - context_dim: int, - num_heads: int, - block_config: str, - mlp_ratio: float = 4.0, - window_sizes: list = [], - spatial_attn_win_size: int = 1, - temporal_attn_win_size: int = 1, - use_checkpoint: bool = False, - x_format: str = "BTHWD", - use_adaln_lora: bool = False, - adaln_lora_dim: int = 256, - n_views: int = 1, - ): - super().__init__() - self.blocks = nn.ModuleList() - self.x_format = x_format - for block_type in block_config.split("-"): - self.blocks.append( - DITBuildingBlock( - block_type, - x_dim, - context_dim, - num_heads, - mlp_ratio, - window_sizes, - spatial_attn_win_size, - temporal_attn_win_size, - x_format=self.x_format, - use_adaln_lora=use_adaln_lora, - adaln_lora_dim=adaln_lora_dim, - n_views=n_views, - ) - ) - self.use_checkpoint = use_checkpoint - - def forward( - self, - x: torch.Tensor, - emb_B_D: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - rope_emb_L_1_1_D: Optional[torch.Tensor] = None, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - extra_per_block_pos_emb: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint( - self._forward, - x, - emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D, - adaln_lora_B_3D, - extra_per_block_pos_emb, - regional_contexts, - region_masks, - use_reentrant=False, - ) - else: - return self._forward( - x, - emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D, - adaln_lora_B_3D, - extra_per_block_pos_emb, - regional_contexts, - region_masks, - ) - - def _forward( - self, - x: torch.Tensor, - emb_B_D: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - rope_emb_L_1_1_D: Optional[torch.Tensor] = None, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - extra_per_block_pos_emb: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if extra_per_block_pos_emb is not None: - x = x + extra_per_block_pos_emb - for block in self.blocks: - x = block( - x, - emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - adaln_lora_B_3D=adaln_lora_B_3D, - regional_contexts=regional_contexts, - region_masks=region_masks, - ) - return x - - def set_memory_save(self, mode: bool = True): - # to make fsdp happy! - #! IMPORTANT! 
- if mode: - self.forward = self.forward_with_memory_save - for block in self.blocks: - block.forward = block.forward_with_memory_save - else: - raise NotImplementedError("Not implemented yet.") - - def forward_with_memory_save( - self, - x_before_gate: torch.Tensor, - x_skip: torch.Tensor, - gate_L_B_D: torch.Tensor, - emb_B_D: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - rope_emb_L_1_1_D: Optional[torch.Tensor] = None, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - extra_per_block_pos_emb: Optional[torch.Tensor] = None, - regional_contexts: Optional[torch.Tensor] = None, - region_masks: Optional[torch.Tensor] = None, - ): - for block in self.blocks: - gate_L_B_D, x_before_gate, x_skip = block.forward( - x_before_gate, - x_skip, - gate_L_B_D, - emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D, - adaln_lora_B_3D, - extra_per_block_pos_emb, - regional_contexts=regional_contexts, - region_masks=region_masks, - ) - extra_per_block_pos_emb = None - return gate_L_B_D, x_before_gate, x_skip diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/modules/edm_sde.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/modules/edm_sde.py deleted file mode 100644 index 3d08a8229f03c9fdd6a8d905ad4543fe5fe5238a..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/modules/edm_sde.py +++ /dev/null @@ -1,43 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from statistics import NormalDist - -import numpy as np -import torch - - -class EDMSDE: - def __init__( - self, - p_mean: float = -1.2, - p_std: float = 1.2, - sigma_max: float = 80.0, - sigma_min: float = 0.002, - ): - self.gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) - self.sigma_max = sigma_max - self.sigma_min = sigma_min - - def sample_t(self, batch_size: int) -> torch.Tensor: - cdf_vals = np.random.uniform(size=(batch_size)) - samples_interval_gaussian = [self.gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] - - log_sigma = torch.tensor(samples_interval_gaussian, device="cuda") - return torch.exp(log_sigma) - - def marginal_prob(self, x0: torch.Tensor, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """This is trivial in the base class, but may be used by derived classes in a more interesting way""" - return x0, sigma diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/modules/pretrained_vae.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/modules/pretrained_vae.py deleted file mode 100644 index d55a8d670a889d2fb8ce1602c96f25b6e613bed0..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/modules/pretrained_vae.py +++ /dev/null @@ -1,736 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
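# --- Editor's aside (not part of either file): EDMSDE.sample_t above draws
# noise levels from a log-normal distribution via an inverse CDF; an equivalent
# way to see it, shown here as an assumption-laden sketch, is
# sigma = exp(p_mean + p_std * z) with z ~ N(0, 1).
import torch

p_mean, p_std, batch_size = -1.2, 1.2, 4
sigma = torch.exp(p_mean + p_std * torch.randn(batch_size))
print(sigma)   # noise levels concentrated around exp(-1.2) ~= 0.30
# --- end editor's aside ---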
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from abc import ABC, abstractmethod - -import torch -import torch.nn.functional as F -from einops import rearrange -from torch.nn.modules import Module - -from cosmos_transfer1.utils.distributed import rank0_first - - -class BaseVAE(torch.nn.Module, ABC): - """ - Abstract base class for a Variational Autoencoder (VAE). - - All subclasses should implement the methods to define the behavior for encoding - and decoding, along with specifying the latent channel size. - """ - - def __init__(self, channel: int = 3, name: str = "vae"): - super().__init__() - self.channel = channel - self.name = name - - @property - def latent_ch(self) -> int: - """ - Returns the number of latent channels in the VAE. - """ - return self.channel - - @abstractmethod - def encode(self, state: torch.Tensor) -> torch.Tensor: - """ - Encodes the input tensor into a latent representation. - - Args: - - state (torch.Tensor): The input tensor to encode. - - Returns: - - torch.Tensor: The encoded latent tensor. - """ - pass - - @abstractmethod - def decode(self, latent: torch.Tensor) -> torch.Tensor: - """ - Decodes the latent representation back to the original space. - - Args: - - latent (torch.Tensor): The latent tensor to decode. - - Returns: - - torch.Tensor: The decoded tensor. - """ - pass - - @property - def spatial_compression_factor(self) -> int: - """ - Returns the spatial reduction factor for the VAE. - """ - raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.") - - -class VideoTokenizerInterface(ABC): - @abstractmethod - def encode(self, state: torch.Tensor) -> torch.Tensor: - pass - - @abstractmethod - def decode(self, latent: torch.Tensor) -> torch.Tensor: - pass - - @abstractmethod - def get_latent_num_frames(self, num_pixel_frames: int) -> int: - pass - - @abstractmethod - def get_pixel_num_frames(self, num_latent_frames: int) -> int: - pass - - @property - @abstractmethod - def spatial_compression_factor(self): - pass - - @property - @abstractmethod - def temporal_compression_factor(self): - pass - - @property - @abstractmethod - def spatial_resolution(self): - pass - - @property - @abstractmethod - def pixel_chunk_duration(self): - pass - - @property - @abstractmethod - def latent_chunk_duration(self): - pass - - @property - def is_chunk_overlap(self): - return False - - -class BasePretrainedVideoTokenizer(ABC): - """ - Base class for a pretrained video tokenizer that handles chunking of video data for efficient processing. - - Args: - pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level. - temporal_compress_factor (int): The factor by which the video data is temporally compressed during processing. - max_enc_batch_size (int): The maximum batch size to process in one go during encoding to avoid memory overflow. 
- max_dec_batch_size (int): The maximum batch size to process in one go during decoding to avoid memory overflow. - - The class introduces parameters for managing temporal chunks (`pixel_chunk_duration` and `temporal_compress_factor`) - which define how video data is subdivided and compressed during the encoding and decoding processes. The - `max_enc_batch_size` and `max_dec_batch_size` parameters allow processing in smaller batches to handle memory - constraints. - """ - - def __init__( - self, - pixel_chunk_duration: int = 17, - temporal_compress_factor: int = 8, - max_enc_batch_size: int = 8, - max_dec_batch_size: int = 4, - ): - self._pixel_chunk_duration = pixel_chunk_duration - self._temporal_compress_factor = temporal_compress_factor - self.max_enc_batch_size = max_enc_batch_size - self.max_dec_batch_size = max_dec_batch_size - - def register_mean_std(self, mean_std_fp: str) -> None: - latent_mean, latent_std = torch.load(mean_std_fp, map_location="cuda", weights_only=True) - latent_mean = latent_mean.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] - latent_std = latent_std.view(self.latent_ch, -1)[:, : self.latent_chunk_duration] - - target_shape = [1, self.latent_ch, self.latent_chunk_duration, 1, 1] - - self.register_buffer( - "latent_mean", - latent_mean.to(self.dtype).reshape(*target_shape), - persistent=False, - ) - self.register_buffer( - "latent_std", - latent_std.to(self.dtype).reshape(*target_shape), - persistent=False, - ) - - def transform_encode_state_shape(self, state: torch.Tensor) -> torch.Tensor: - """ - Rearranges the input state tensor to the required shape for encoding video data. Mainly for chunk based encoding - """ - B, C, T, H, W = state.shape - assert ( - T % self.pixel_chunk_duration == 0 - ), f"Temporal dimension {T} is not divisible by chunk_length {self.pixel_chunk_duration}" - return rearrange(state, "b c (n t) h w -> (b n) c t h w", t=self.pixel_chunk_duration) - - def transform_decode_state_shape(self, latent: torch.Tensor) -> None: - B, _, T, _, _ = latent.shape - assert ( - T % self.latent_chunk_duration == 0 - ), f"Temporal dimension {T} is not divisible by chunk_length {self.latent_chunk_duration}" - return rearrange(latent, "b c (n t) h w -> (b n) c t h w", t=self.latent_chunk_duration) - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - if self._temporal_compress_factor == 1: - _, _, origin_T, _, _ = state.shape - state = rearrange(state, "b c t h w -> (b t) c 1 h w") - B, C, T, H, W = state.shape - state = self.transform_encode_state_shape(state) - # use max_enc_batch_size to avoid OOM - if state.shape[0] > self.max_enc_batch_size: - latent = [] - for i in range(0, state.shape[0], self.max_enc_batch_size): - latent.append(super().encode(state[i : i + self.max_enc_batch_size])) - latent = torch.cat(latent, dim=0) - else: - latent = super().encode(state) - - latent = rearrange(latent, "(b n) c t h w -> b c (n t) h w", b=B) - if self._temporal_compress_factor == 1: - latent = rearrange(latent, "(b t) c 1 h w -> b c t h w", t=origin_T) - return latent - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - """ - Decodes a batch of latent representations into video frames by applying temporal chunking. Similar to encode, - it handles video data by processing smaller temporal chunks to reconstruct the original video dimensions. - - It can also decode single frame image data. - - Args: - latent (torch.Tensor): The latent space tensor containing encoded video data. 
- - Returns: - torch.Tensor: The decoded video tensor reconstructed from latent space. - """ - if self._temporal_compress_factor == 1: - _, _, origin_T, _, _ = latent.shape - latent = rearrange(latent, "b c t h w -> (b t) c 1 h w") - B, _, T, _, _ = latent.shape - latent = self.transform_decode_state_shape(latent) - # use max_enc_batch_size to avoid OOM - if latent.shape[0] > self.max_dec_batch_size: - state = [] - for i in range(0, latent.shape[0], self.max_dec_batch_size): - state.append(super().decode(latent[i : i + self.max_dec_batch_size])) - state = torch.cat(state, dim=0) - else: - state = super().decode(latent) - assert state.shape[2] == self.pixel_chunk_duration - state = rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B) - if self._temporal_compress_factor == 1: - return rearrange(state, "(b t) c 1 h w -> b c t h w", t=origin_T) - return state - - @property - def pixel_chunk_duration(self) -> int: - return self._pixel_chunk_duration - - @property - def latent_chunk_duration(self) -> int: - # return self._latent_chunk_duration - assert (self.pixel_chunk_duration - 1) % self.temporal_compression_factor == 0, ( - f"Pixel chunk duration {self.pixel_chunk_duration} is not divisible by latent chunk duration " - f"{self.latent_chunk_duration}" - ) - return (self.pixel_chunk_duration - 1) // self.temporal_compression_factor + 1 - - @property - def temporal_compression_factor(self): - return self._temporal_compress_factor - - def get_latent_num_frames(self, num_pixel_frames: int) -> int: - if num_pixel_frames == 1: - return 1 - assert ( - num_pixel_frames % self.pixel_chunk_duration == 0 - ), f"Temporal dimension {num_pixel_frames} is not divisible by chunk_length {self.pixel_chunk_duration}" - return num_pixel_frames // self.pixel_chunk_duration * self.latent_chunk_duration - - def get_pixel_num_frames(self, num_latent_frames: int) -> int: - if num_latent_frames == 1: - return 1 - assert ( - num_latent_frames % self.latent_chunk_duration == 0 - ), f"Temporal dimension {num_latent_frames} is not divisible by chunk_length {self.latent_chunk_duration}" - return num_latent_frames // self.latent_chunk_duration * self.pixel_chunk_duration - - -class BasePretrainedImageVAE(BaseVAE): - """ - A base class for pretrained Variational Autoencoder (VAE) that loads mean and standard deviation values - from a remote store, handles data type conversions, and normalization - using provided mean and standard deviation values for latent space representation. - Derived classes should load pre-trained encoder and decoder components from a remote store - - Attributes: - latent_mean (Tensor): The mean used for normalizing the latent representation. - latent_std (Tensor): The standard deviation used for normalizing the latent representation. - dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. - - Args: - mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. - latent_ch (int, optional): Number of latent channels (default is 16). - is_image (bool, optional): Flag to indicate whether the output is an image (default is True). - is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). 
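# --- Editor's aside (not part of the original file): a worked example of the
# chunked frame bookkeeping above, using the class defaults
# pixel_chunk_duration=17 and temporal_compress_factor=8.
pixel_chunk_duration = 17
temporal_compression_factor = 8
latent_chunk_duration = (pixel_chunk_duration - 1) // temporal_compression_factor + 1   # = 3

num_pixel_frames = 34                                     # two chunks of 17 frames
num_chunks = num_pixel_frames // pixel_chunk_duration     # = 2
num_latent_frames = num_chunks * latent_chunk_duration    # = 6, matching get_latent_num_frames(34)
# --- end editor's aside ---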
- """ - - def __init__( - self, - name: str, - mean_std_fp: str, - latent_ch: int = 16, - is_image: bool = True, - is_bf16: bool = True, - ) -> None: - super().__init__(latent_ch, name) - dtype = torch.bfloat16 if is_bf16 else torch.float32 - self.dtype = dtype - self.is_image = is_image - self.mean_std_fp = mean_std_fp - self.name = name - - self.backend_args = None - - self.register_mean_std(mean_std_fp) - - def register_mean_std(self, mean_std_fp: str) -> None: - latent_mean, latent_std = torch.load(mean_std_fp, map_location="cuda", weights_only=True) - target_shape = [1, self.latent_ch, 1, 1] if self.is_image else [1, self.latent_ch, 1, 1, 1] - self.register_buffer( - "latent_mean", - latent_mean.to(self.dtype).reshape(*target_shape), - persistent=False, - ) - self.register_buffer( - "latent_std", - latent_std.to(self.dtype).reshape(*target_shape), - persistent=False, - ) - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - """ - Encode the input state to latent space; also handle the dtype conversion, mean and std scaling - """ - in_dtype = state.dtype - latent_mean = self.latent_mean.to(in_dtype) - latent_std = self.latent_std.to(in_dtype) - encoded_state = self.encoder(state.to(self.dtype)) - if isinstance(encoded_state, torch.Tensor): - pass - elif isinstance(encoded_state, tuple): - assert isinstance(encoded_state[0], torch.Tensor) - encoded_state = encoded_state[0] - else: - raise ValueError("Invalid type of encoded state") - return (encoded_state.to(in_dtype) - latent_mean) / latent_std - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - """ - Decode the input latent to state; also handle the dtype conversion, mean and std scaling - """ - in_dtype = latent.dtype - latent = latent * self.latent_std.to(in_dtype) + self.latent_mean.to(in_dtype) - return self.decoder(latent.to(self.dtype)).to(in_dtype) - - def reset_dtype(self, *args, **kwargs): - """ - Resets the data type of the encoder and decoder to the model's default data type. - - Args: - *args, **kwargs: Unused, present to allow flexibility in method calls. - """ - del args, kwargs - self.decoder.to(self.dtype) - self.encoder.to(self.dtype) - - -class JITVAE(BasePretrainedImageVAE): - """ - A JIT compiled Variational Autoencoder (VAE) that loads pre-trained encoder - and decoder components from a remote store, handles data type conversions, and normalization - using provided mean and standard deviation values for latent space representation. - - Attributes: - encoder (Module): The JIT compiled encoder loaded from storage. - decoder (Module): The JIT compiled decoder loaded from storage. - latent_mean (Tensor): The mean used for normalizing the latent representation. - latent_std (Tensor): The standard deviation used for normalizing the latent representation. - dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. - - Args: - enc_fp (str): File path to the encoder's JIT file on the remote store. - dec_fp (str): File path to the decoder's JIT file on the remote store. - name (str): Name of the model, used for differentiating cache file paths. - mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. - latent_ch (int, optional): Number of latent channels (default is 16). - is_image (bool, optional): Flag to indicate whether the output is an image (default is True). - is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). 
- """ - - def __init__( - self, - enc_fp: str, - dec_fp: str, - name: str, - mean_std_fp: str, - latent_ch: int = 16, - is_image: bool = True, - is_bf16: bool = True, - ): - super().__init__(name, mean_std_fp, latent_ch, is_image, is_bf16) - self.load_encoder(enc_fp) - self.load_decoder(dec_fp) - - def load_encoder(self, enc_fp: str) -> None: - """ - Load the encoder from the remote store. - - Args: - - enc_fp (str): File path to the encoder's JIT file on the remote store. - """ - self.encoder = torch.jit.load(enc_fp, map_location="cuda") - self.encoder.eval() - for param in self.encoder.parameters(): - param.requires_grad = False - self.encoder.to(self.dtype) - - def load_decoder(self, dec_fp: str) -> None: - """ - Load the decoder from the remote store. - - Args: - - dec_fp (str): File path to the decoder's JIT file on the remote store. - """ - self.decoder = torch.jit.load(dec_fp, map_location="cuda") - self.decoder.eval() - for param in self.decoder.parameters(): - param.requires_grad = False - self.decoder.to(self.dtype) - - -# class StateDictVAE(BasePretrainedImageVAE): -# """ -# A Variational Autoencoder (VAE) that loads pre-trained weights into -# provided encoder and decoder components from a remote store, handles data type conversions, -# and normalization using provided mean and standard deviation values for latent space representation. - -# Attributes: -# encoder (Module): The encoder with weights loaded from storage. -# decoder (Module): The decoder with weights loaded from storage. -# latent_mean (Tensor): The mean used for normalizing the latent representation. -# latent_std (Tensor): The standard deviation used for normalizing the latent representation. -# dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. - -# Args: -# enc_fp (str): File path to the encoder's JIT file on the remote store. -# dec_fp (str): File path to the decoder's JIT file on the remote store. -# vae (Module): Instance of VAE with not loaded weights -# name (str): Name of the model, used for differentiating cache file paths. -# mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. -# latent_ch (int, optional): Number of latent channels (default is 16). -# is_image (bool, optional): Flag to indicate whether the output is an image (default is True). -# is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). -# """ - -# def __init__( -# self, -# enc_fp: str, -# dec_fp: str, -# vae: torch.nn.Module, -# name: str, -# mean_std_fp: str, -# latent_ch: int = 16, -# is_image: bool = True, -# is_bf16: bool = True, -# ): -# super().__init__(name, mean_std_fp, latent_ch, is_image, is_bf16) - -# self.load_encoder_and_decoder(enc_fp, dec_fp, vae) - -# def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, vae: torch.nn.Module) -> None: -# """ -# Load the encoder from the remote store. - -# Args: -# - vae_fp (str): File path to the vae's state dict file on the remote store. -# - vae (str): VAE module into which weights will be loaded. 
-# """ -# state_dict_enc = load_from_s3_with_cache( -# enc_fp, -# f"vae/{self.name}_enc.jit", -# easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())}, -# backend_args=self.backend_args, -# ) - -# state_dict_dec = load_from_s3_with_cache( -# dec_fp, -# f"vae/{self.name}_dec.jit", -# easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())}, -# backend_args=self.backend_args, -# ) - -# jit_weights_state_dict = state_dict_enc.state_dict() | state_dict_dec.state_dict() -# jit_weights_state_dict = { -# k: v -# for k, v in jit_weights_state_dict.items() -# # Global variables captured by JIT -# if k -# not in ( -# "encoder.patcher.wavelets", -# "encoder.patcher._arange", -# "decoder.unpatcher.wavelets", -# "decoder.unpatcher._arange", -# ) -# } - -# vae.load_state_dict(jit_weights_state_dict) -# vae.eval() -# for param in vae.parameters(): -# param.requires_grad = False -# vae.to(self.dtype) - -# self.vae = vae -# self.encoder = self.vae.encode -# self.decoder = self.vae.decode - -# def reset_dtype(self, *args, **kwargs): -# """ -# Resets the data type of the encoder and decoder to the model's default data type. - -# Args: -# *args, **kwargs: Unused, present to allow flexibility in method calls. -# """ -# del args, kwargs -# self.vae.to(self.dtype) - - -# class SDVAE(BaseVAE): -# def __init__(self, batch_size=16, count_std: bool = False, is_downsample: bool = True) -> None: -# super().__init__(channel=4, name="sd_vae") -# self.dtype = torch.bfloat16 -# self.register_buffer( -# "scale", -# torch.tensor([4.17, 4.62, 3.71, 3.28], dtype=self.dtype).reciprocal().reshape(1, -1, 1, 1), -# persistent=False, -# ) -# self.register_buffer( -# "bias", -# -1.0 * torch.tensor([5.81, 3.25, 0.12, -2.15], dtype=self.dtype).reshape(1, -1, 1, 1) * self.scale, -# persistent=False, -# ) -# self.batch_size = batch_size -# self.count_std = count_std -# self.is_downsample = is_downsample -# self.load_vae() -# self.reset_dtype() - -# def reset_dtype(self, *args, **kwargs): -# del args, kwargs -# self.vae.to(self.dtype) - -# @rank0_first -# def load_vae(self) -> None: -# os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1" -# os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" -# import diffusers - -# vae_name = "stabilityai/sd-vae-ft-mse" -# try: -# vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, local_files_only=True) -# except: # noqa: E722 -# # Could not load the model from cache; try without local_files_only. 
-# vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name) -# self.vae = vae.eval().requires_grad_(False) - -# @torch.no_grad() -# def encode(self, state: torch.Tensor) -> torch.Tensor: -# """ -# state : pixel range [-1, 1] -# """ -# if self.is_downsample: -# _h, _w = state.shape[-2:] -# state = F.interpolate(state, size=(_h // 2, _w // 2), mode="bilinear", align_corners=False) -# in_dtype = state.dtype -# state = state.to(self.dtype) -# state = (state + 1.0) / 2.0 -# latent_dist = self.vae.encode(state)["latent_dist"] -# mean, std = latent_dist.mean, latent_dist.std -# if self.count_std: -# latent = mean + torch.randn_like(mean) * std -# else: -# latent = mean -# latent = latent * self.scale -# latent = latent + self.bias -# return latent.to(in_dtype) - -# @torch.no_grad() -# def decode(self, latent: torch.Tensor) -> torch.Tensor: -# in_dtype = latent.dtype -# latent = latent.to(self.dtype) -# latent = latent - self.bias -# latent = latent / self.scale -# latent = torch.cat([self.vae.decode(batch)["sample"] for batch in latent.split(self.batch_size)]) -# if self.is_downsample: -# _h, _w = latent.shape[-2:] -# latent = F.interpolate(latent, size=(_h * 2, _w * 2), mode="bilinear", align_corners=False) -# return latent.to(in_dtype) * 2 - 1.0 - -# @property -# def spatial_compression_factor(self) -> int: -# return 8 - - -class VideoJITTokenizer(BasePretrainedVideoTokenizer, JITVAE, VideoTokenizerInterface): - """ - Instance of BasePretrainedVideoVAE that loads encoder and decoder from JIT scripted module file - """ - - def __init__( - self, - enc_fp: str, - dec_fp: str, - name: str, - mean_std_fp: str, - latent_ch: int = 16, - is_bf16: bool = True, - spatial_compression_factor: int = 16, - temporal_compression_factor: int = 8, - pixel_chunk_duration: int = 17, - max_enc_batch_size: int = 8, - max_dec_batch_size: int = 4, - spatial_resolution: str = "720", - ): - super().__init__(pixel_chunk_duration, temporal_compression_factor, max_enc_batch_size, max_dec_batch_size) - super(BasePretrainedVideoTokenizer, self).__init__(enc_fp, dec_fp, name, mean_std_fp, latent_ch, False, is_bf16) - - self._spatial_compression_factor = spatial_compression_factor - self._spatial_resolution = spatial_resolution - - @property - def spatial_compression_factor(self): - return self._spatial_compression_factor - - @property - def spatial_resolution(self) -> str: - return self._spatial_resolution - - -class JointImageVideoTokenizer(BaseVAE, VideoTokenizerInterface): - def __init__( - self, - image_vae: torch.nn.Module, - video_vae: torch.nn.Module, - name: str, - latent_ch: int = 16, - squeeze_for_image: bool = True, - ): - super().__init__(latent_ch, name) - self.image_vae = image_vae - self.video_vae = video_vae - self.squeeze_for_image = squeeze_for_image - - def encode_image(self, state: torch.Tensor) -> torch.Tensor: - if self.squeeze_for_image: - return self.image_vae.encode(state.squeeze(2)).unsqueeze(2) - return self.image_vae.encode(state) - - def decode_image(self, latent: torch.Tensor) -> torch.Tensor: - if self.squeeze_for_image: - return self.image_vae.decode(latent.squeeze(2)).unsqueeze(2) - return self.image_vae.decode(latent) - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - B, C, T, H, W = state.shape - if T == 1: - return self.encode_image(state) - - return self.video_vae.encode(state) - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - B, C, T, H, W = latent.shape - if T == 1: - return self.decode_image(latent) - return 
self.video_vae.decode(latent) - - def reset_dtype(self, *args, **kwargs): - """ - Resets the data type of the encoder and decoder to the model's default data type. - - Args: - *args, **kwargs: Unused, present to allow flexibility in method calls. - """ - del args, kwargs - self.image_vae.reset_dtype() - self.video_vae.reset_dtype() - - def get_latent_num_frames(self, num_pixel_frames: int) -> int: - if num_pixel_frames == 1: - return 1 - return self.video_vae.get_latent_num_frames(num_pixel_frames) - - def get_pixel_num_frames(self, num_latent_frames: int) -> int: - if num_latent_frames == 1: - return 1 - return self.video_vae.get_pixel_num_frames(num_latent_frames) - - @property - def spatial_compression_factor(self): - return self.video_vae.spatial_compression_factor - - @property - def temporal_compression_factor(self): - return self.video_vae.temporal_compression_factor - - @property - def spatial_resolution(self) -> str: - return self.video_vae.spatial_resolution - - @property - def pixel_chunk_duration(self) -> int: - return self.video_vae.pixel_chunk_duration - - @property - def latent_chunk_duration(self) -> int: - return self.video_vae.latent_chunk_duration - - -class JointImageVideoSharedJITTokenizer(JointImageVideoTokenizer): - """ - First version of the ImageVideoVAE trained with Fitsum. - We have to use seperate mean and std for image and video due to non-causal nature of the model. - """ - - def __init__(self, image_vae: Module, video_vae: Module, name: str, latent_ch: int = 16): - super().__init__(image_vae, video_vae, name, latent_ch, squeeze_for_image=False) - assert isinstance(image_vae, JITVAE) - assert isinstance( - video_vae, VideoJITTokenizer - ), f"video_vae should be an instance of VideoJITVAE, got {type(video_vae)}" - # a hack to make the image_vae and video_vae share the same encoder and decoder - self.image_vae.encoder = self.video_vae.encoder - self.image_vae.decoder = self.video_vae.decoder diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/__init__.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit.py deleted file mode 100644 index 6e2d21dc13427d31e77f7ee397adb76095ce261a..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit.py +++ /dev/null @@ -1,1037 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. -It allows us easy to switch building blocks used and their order. 
Its instantiation includes -* transformer on fully flattened tokens -* factored spatial and temporal attention -* factored non-overlap spatial and temporal attention -* mixing of above attention types - -Limitations: - -* In favor of simplicity and cleanness, many ops are not fused and we can do better -* such as combining mutiple adaln MLPs into one inside one transformer block. -* we use reshape heavily, which may be not efficient when its occurs unnecessary CUDA memory copy - -Purpose: -* A prototype for testing different attention types and their combinations -* Idealy, we want to know where we should allocate our resources / FLOPS / memory via extensive empirical studies -""" - - -from collections.abc import Container -from typing import List, Optional, Tuple - -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import nn -from torch.distributed import ProcessGroup, get_process_group_ranks -from torchvision import transforms - -from cosmos_transfer1.diffusion.conditioner import DataType -from cosmos_transfer1.diffusion.module.attention import get_normalization -from cosmos_transfer1.diffusion.module.position_embedding import ( - LearnableEmb3D, - LearnableEmb3D_FPS_Aware, - LearnablePosEmbAxis, - SinCosPosEmb, - SinCosPosEmb_FPS_Aware, - SinCosPosEmbAxis, - VideoRopePosition3DEmb, - VideoRopePositionEmb, -) -from cosmos_transfer1.diffusion.training.modules.blocks import ( - DITBuildingBlock, - FinalLayer, - GeneralDITTransformerBlock, - PatchEmbed, - SDXLTimestepEmbedding, - SDXLTimesteps, -) -from cosmos_transfer1.diffusion.training.tensor_parallel import gather_along_first_dim, scatter_along_first_dim -from cosmos_transfer1.utils import log - - -class GeneralDIT(nn.Module): - """ - A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. - Attributes: - max_img_h (int): Maximum height of the input images. - max_img_w (int): Maximum width of the input images. - max_frames (int): Maximum number of frames in the video sequence. - in_channels (int): Number of input channels (e.g., RGB channels for color images). - out_channels (int): Number of output channels. - patch_spatial (tuple of int): Spatial resolution of patches for input processing. - patch_temporal (int): Temporal resolution of patches for input processing. - concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. - block_config (str): Configuration of the transformer block, e.g., 'FA-CA-MLP', means - full attention, cross attention, and MLP in sequence in one transformer block. - model_channels (int): Base number of channels used throughout the model. - num_blocks (int): Number of residual blocks per resolution in the transformer. - num_heads (int): Number of heads in the multi-head self-attention layers. - spatial_attn_win_size (int): Window size for the spatial attention mechanism. - temporal_attn_win_size (int): Window size for the temporal attention mechanism. - mlp_ratio (float): Expansion ratio for the MLP (multi-layer perceptron) blocks in the transformer. - use_memory_save (bool): If True, utilizes checkpointing to reduce memory usage during training. (Deprecated) - use_checkpoint (bool): If True, utilizes checkpointing to reduce memory usage during training for all blocks. - crossattn_emb_channels (int): Number of embedding channels used in the cross-attention layers. - use_cross_attn_mask (bool): If True, applies a mask during cross-attention operations to manage sequence alignment. 
- pos_emb_cls (str): Type of positional embeddings used ('sincos' for sinusoidal or other types). - pos_emb_learnable (bool): Specifies if positional embeddings are learnable. - pos_emb_interpolation (str): Method used for interpolating positional embeddings, e.g., 'crop' for cropping adjustments. - block_x_format (str, optional): The format of the input tensor for the transformer block. Defaults to "BTHWD". Only support 'BTHWD' and 'THWBD'. - legacy_patch_emb (bool): If True, applies 3D convolutional layers for video inputs, otherwise, use Linear! This is for backward compatibility. - rope_h_extrapolation_ratio (float): Ratio of the height extrapolation for the rope positional embedding. - rope_w_extrapolation_ratio (float): Ratio of the width extrapolation for the rope positional embedding. - rope_t_extrapolation_ratio (float): Ratio of the temporal extrapolation for the rope positional embedding. - Note: - block_config support block type: - * spatial_sa, ssa: spatial self attention - * temporal_sa, tsa: temporal self attention - * cross_attn, ca: cross attention - * full_attn: full attention on all flatten tokens - * mlp, ff: feed forward block - * use '-' to separate different building blocks, e.g., 'FA-CA-MLP' means full attention, cross attention, and MLP in sequence in one transformer block. - - Example: - >>> # full attention, cross attention, and MLP - >>> option1_block_config = 'FA-CA-MLP' - >>> model_1 = GeneralDIT( - max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, - patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, - num_heads=16, mlp_ratio=4.0, - spatial_attn_win_size=1, temporal_attn_win_size=1, - block_config=option1_block_config - ) - >>> option2_block_config = 'SSA-CA-MLP-TSA-CA-MLP' - >>> model_2 = GeneralDIT( - max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, - patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, - num_heads=16, mlp_ratio=4.0, - spatial_attn_win_size=1, temporal_attn_win_size=1, - block_config=option2_block_config - ) - >>> # option3 model - >>> model_3 = GeneralDIT( - max_img_h=64, max_img_w=64, max_frames=32, in_channels=16, out_channels=16, - patch_spatial=2, patch_temporal=1, model_channels=768, num_blocks=10, - num_heads=16, mlp_ratio=4.0, - spatial_attn_win_size=1, temporal_attn_win_size=2, - block_config=option2_block_config - ) - >>> # Process input tensor through the model - >>> output = model(input_tensor) - """ - - def __init__( - self, - max_img_h: int, - max_img_w: int, - max_frames: int, - in_channels: int, - out_channels: int, - patch_spatial: tuple, - patch_temporal: int, - concat_padding_mask: bool = True, - # attention settings - block_config: str = "FA-CA-MLP", - model_channels: int = 768, - num_blocks: int = 10, - num_heads: int = 16, - window_block_indexes: list = [], # index for window attention block - window_sizes: list = [], # window size for window attention block in the order of T, H, W - spatial_attn_win_size: int = 1, - temporal_attn_win_size: int = 1, - mlp_ratio: float = 4.0, - use_memory_save: bool = False, - use_checkpoint: bool = False, - block_x_format: str = "BTHWD", - # cross attention settings - crossattn_emb_channels: int = 1024, - use_cross_attn_mask: bool = False, - # positional embedding settings - pos_emb_cls: str = "sincos", - pos_emb_learnable: bool = False, - pos_emb_interpolation: str = "crop", - min_fps: int = 1, # 1 for getty video - max_fps: int = 30, # 120 for getty video but let's use 30 - 
additional_timestamp_channels: dict = None, # Follow SDXL, in format of {condition_name : dimension} - affline_emb_norm: bool = False, # whether or not to normalize the affine embedding - use_adaln_lora: bool = False, - adaln_lora_dim: int = 256, - layer_mask: list = None, # whether or not a layer is used. For controlnet encoder - legacy_patch_emb: bool = True, - rope_h_extrapolation_ratio: float = 1.0, - rope_w_extrapolation_ratio: float = 1.0, - rope_t_extrapolation_ratio: float = 1.0, - extra_per_block_abs_pos_emb: bool = False, - extra_per_block_abs_pos_emb_type: str = "sincos", - extra_h_extrapolation_ratio: float = 1.0, - extra_w_extrapolation_ratio: float = 1.0, - extra_t_extrapolation_ratio: float = 1.0, - ) -> None: - super().__init__() - self.max_img_h = max_img_h - self.max_img_w = max_img_w - self.max_frames = max_frames - self.in_channels = in_channels - self.out_channels = out_channels - self.patch_spatial = patch_spatial - self.patch_temporal = patch_temporal - self.num_heads = num_heads - self.num_blocks = num_blocks - self.model_channels = model_channels - self.use_cross_attn_mask = use_cross_attn_mask - self.concat_padding_mask = concat_padding_mask - # positional embedding settings - self.pos_emb_cls = pos_emb_cls - self.pos_emb_learnable = pos_emb_learnable - self.pos_emb_interpolation = pos_emb_interpolation - self.min_fps = min_fps - self.max_fps = max_fps - self.additional_timestamp_channels = additional_timestamp_channels - self.affline_emb_norm = affline_emb_norm - self.legacy_patch_emb = legacy_patch_emb - self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio - self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio - self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio - self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb - self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower() - self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio - self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio - self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio - - self.build_patch_embed() - self.build_pos_embed() - self.cp_group = None - self.sequence_parallel = getattr(parallel_state, "sequence_parallel", False) - self.block_x_format = block_x_format - self.use_adaln_lora = use_adaln_lora - self.adaln_lora_dim = adaln_lora_dim - self.t_embedder = nn.Sequential( - SDXLTimesteps(model_channels), - SDXLTimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora), - ) - - self.blocks = nn.ModuleDict() - self.block_config = block_config - self.use_memory_save = use_memory_save - self.use_checkpoint = use_checkpoint - - assert ( - len(window_block_indexes) == 0 or block_config == "FA-CA-MLP" - ), "Block config must be FA-CA-MLP if using a combination of window attention and global attention" - - layer_mask = [False] * num_blocks if layer_mask is None else layer_mask - assert ( - len(layer_mask) == num_blocks - ), f"Layer mask length {len(layer_mask)} does not match num_blocks {num_blocks}" - for idx in range(num_blocks): - if layer_mask[idx]: - continue - self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( - x_dim=model_channels, - context_dim=crossattn_emb_channels, - num_heads=num_heads, - block_config=block_config, - window_sizes=( - window_sizes if idx in window_block_indexes else [] - ), # There will be bug if using "WA-CA-MLP" - mlp_ratio=mlp_ratio, - spatial_attn_win_size=spatial_attn_win_size, - temporal_attn_win_size=temporal_attn_win_size, - 
x_format=self.block_x_format, - use_adaln_lora=use_adaln_lora, - adaln_lora_dim=adaln_lora_dim, - use_checkpoint=use_checkpoint, - ) - - self.build_decode_head() - self.build_additional_timestamp_embedder() - if self.affline_emb_norm: - log.critical("Building affine embedding normalization layer") - self.affline_norm = get_normalization("R", model_channels) - else: - self.affline_norm = nn.Identity() - self.init_weights() - - if self.use_memory_save: - log.critical("Using checkpointing to save memory! only verified in 14B base model training!") - for block in self.blocks.values(): - block.set_memory_save() - - def init_weights(self): - # Initialize transformer layers: - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - self.apply(_basic_init) - - # Initialize timestep embedding - nn.init.normal_(self.t_embedder[1].linear_1.weight, std=0.02) - if self.t_embedder[1].linear_1.bias is not None: - nn.init.constant_(self.t_embedder[1].linear_1.bias, 0) - nn.init.normal_(self.t_embedder[1].linear_2.weight, std=0.02) - if self.t_embedder[1].linear_2.bias is not None: - nn.init.constant_(self.t_embedder[1].linear_2.bias, 0) - - # Zero-out adaLN modulation layers in DiT blocks: - for transformer_block in self.blocks.values(): - for block in transformer_block.blocks: - nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - if block.adaLN_modulation[-1].bias is not None: - nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - # Tensor parallel - if parallel_state.is_initialized() and parallel_state.get_tensor_model_parallel_world_size() > 1: - self.initialize_tensor_parallel_weights() - - def initialize_tensor_parallel_weights(self): - """ - Initialize weights for tensor parallel layers. - - This function performs the following steps: - 1. Retrieves the tensor parallel rank. - 2. Saves the current random state. - 3. Sets a new random seed based on the tensor parallel rank. - 4. Initializes weights for attention and MLP layers in each block. - 5. Restores the original random state. - - The use of different random seeds for each rank ensures - unique initializations across parallel processes. 
- """ - tp_rank = parallel_state.get_tensor_model_parallel_rank() - - # Save the current random state - rng_state = torch.get_rng_state() - - # Set a new random seed based on the tensor parallel rank - torch.manual_seed(tp_rank) - - for block in self.blocks.values(): - for layer in block.blocks: - if layer.block_type in ["full_attn", "fa", "cross_attn", "ca"]: - # Initialize weights for attention layers - torch.nn.init.xavier_uniform_(layer.block.attn.to_q[0].weight) - torch.nn.init.xavier_uniform_(layer.block.attn.to_k[0].weight) - torch.nn.init.xavier_uniform_(layer.block.attn.to_v[0].weight) - torch.nn.init.xavier_uniform_(layer.block.attn.to_out[0].weight) - elif layer.block_type in ["mlp", "ff"]: - # Initialize weights for MLP layers - torch.nn.init.xavier_uniform_(layer.block.layer1.weight) - torch.nn.init.xavier_uniform_(layer.block.layer2.weight) - else: - raise ValueError(f"Unknown block type {layer.block_type}") - - # Restore the original random state - torch.set_rng_state(rng_state) - - def build_decode_head(self): - self.final_layer = FinalLayer( - hidden_size=self.model_channels, - spatial_patch_size=self.patch_spatial, - temporal_patch_size=self.patch_temporal, - out_channels=self.out_channels, - use_adaln_lora=self.use_adaln_lora, - adaln_lora_dim=self.adaln_lora_dim, - ) - - def build_patch_embed(self): - ( - concat_padding_mask, - in_channels, - patch_spatial, - patch_temporal, - model_channels, - ) = ( - self.concat_padding_mask, - self.in_channels, - self.patch_spatial, - self.patch_temporal, - self.model_channels, - ) - in_channels = in_channels + 1 if concat_padding_mask else in_channels - self.x_embedder = PatchEmbed( - spatial_patch_size=patch_spatial, - temporal_patch_size=patch_temporal, - in_channels=in_channels, - out_channels=model_channels, - bias=False, - keep_spatio=True, - legacy_patch_emb=self.legacy_patch_emb, - ) - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) - if self.legacy_patch_emb: - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - - def build_additional_timestamp_embedder(self): - if self.additional_timestamp_channels: - self.additional_timestamp_embedder = nn.ModuleDict() - for cond_name, cond_emb_channels in self.additional_timestamp_channels.items(): - log.critical( - f"Building additional timestamp embedder for {cond_name} with {cond_emb_channels} channels" - ) - self.additional_timestamp_embedder[cond_name] = nn.Sequential( - SDXLTimesteps(cond_emb_channels), - SDXLTimestepEmbedding(cond_emb_channels, cond_emb_channels), - ) - - def prepare_additional_timestamp_embedder(self, **kwargs): - condition_concat = [] - - for cond_name, embedder in self.additional_timestamp_embedder.items(): - condition_concat.append(embedder(kwargs[cond_name])[0]) - embedding = torch.cat(condition_concat, dim=1) - if embedding.shape[1] < self.model_channels: - embedding = nn.functional.pad(embedding, (0, self.model_channels - embedding.shape[1])) - return embedding - - def build_pos_embed(self): - if self.pos_emb_cls == "sincos": - cls_type = SinCosPosEmb - elif self.pos_emb_cls == "learnable": - cls_type = LearnableEmb3D - elif self.pos_emb_cls == "sincos_fps_aware": - cls_type = SinCosPosEmb_FPS_Aware - elif self.pos_emb_cls == "learnable_fps_aware": - cls_type = LearnableEmb3D_FPS_Aware - elif self.pos_emb_cls == "rope": - cls_type = VideoRopePositionEmb - elif self.pos_emb_cls == "rope3d": - cls_type = VideoRopePosition3DEmb - else: - raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") - - 
log.critical(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") - kwargs = dict( - model_channels=self.model_channels, - len_h=self.max_img_h // self.patch_spatial, - len_w=self.max_img_w // self.patch_spatial, - len_t=self.max_frames // self.patch_temporal, - max_fps=self.max_fps, - min_fps=self.min_fps, - is_learnable=self.pos_emb_learnable, - interpolation=self.pos_emb_interpolation, - head_dim=self.model_channels // self.num_heads, - h_extrapolation_ratio=self.rope_h_extrapolation_ratio, - w_extrapolation_ratio=self.rope_w_extrapolation_ratio, - t_extrapolation_ratio=self.rope_t_extrapolation_ratio, - ) - self.pos_embedder = cls_type( - **kwargs, - ) - - if self.extra_per_block_abs_pos_emb: - assert self.extra_per_block_abs_pos_emb_type in [ - "sincos", - "learnable", - ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" - kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio - kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio - kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio - if self.extra_per_block_abs_pos_emb_type == "sincos": - self.extra_pos_embedder = SinCosPosEmbAxis( - **kwargs, - ) - elif self.extra_per_block_abs_pos_emb_type == "learnable": - self.extra_pos_embedder = LearnablePosEmbAxis( - **kwargs, - ) - - def prepare_embedded_sequence( - self, - x_B_C_T_H_W: torch.Tensor, - fps: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. - - Args: - x_B_C_T_H_W (torch.Tensor): video - fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. - If None, a default value (`self.base_fps`) will be used. - padding_mask (Optional[torch.Tensor]): current it is not used - - Returns: - Tuple[torch.Tensor, Optional[torch.Tensor]]: - - A tensor of shape (B, T, H, W, D) with the embedded sequence. - - An optional positional embedding tensor, returned only if the positional embedding class - (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. - - Notes: - - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. - - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. - - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using - the `self.pos_embedder` with the shape [T, H, W]. - - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` - with the fps tensor. - - Otherwise, the positional embeddings are generated without considering fps. 
- """ - if self.concat_padding_mask: - padding_mask = transforms.functional.resize( - padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST - ) - x_B_C_T_H_W = torch.cat( - [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 - ) - x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) - - if self.extra_per_block_abs_pos_emb: - extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) - else: - extra_pos_emb = None - - if "rope" in self.pos_emb_cls.lower(): - return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb - - if "fps_aware" in self.pos_emb_cls: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] - else: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] - return x_B_T_H_W_D, None, extra_pos_emb - - def decoder_head( - self, - x_B_T_H_W_D: torch.Tensor, - emb_B_D: torch.Tensor, - crossattn_emb: torch.Tensor, - origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W] - crossattn_mask: Optional[torch.Tensor] = None, - adaln_lora_B_3D: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - del crossattn_emb, crossattn_mask - B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape - x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D") - x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D) - # This is to ensure x_BT_HW_D has the correct shape because - # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). - x_BT_HW_D = x_BT_HW_D.view( - B * T_before_patchify // self.patch_temporal, - H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, - -1, - ) - x_B_D_T_H_W = rearrange( - x_BT_HW_D, - "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", - p1=self.patch_spatial, - p2=self.patch_spatial, - H=H_before_patchify // self.patch_spatial, - W=W_before_patchify // self.patch_spatial, - t=self.patch_temporal, - B=B, - ) - return x_B_D_T_H_W - - def forward_before_blocks( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """ - Args: - x: (B, C, T, H, W) tensor of spatial-temp inputs - timesteps: (B, ) tensor of timesteps - crossattn_emb: (B, N, D) tensor of cross-attention embeddings - crossattn_mask: (B, N) tensor of cross-attention masks - """ - del kwargs - assert isinstance( - data_type, DataType - ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
- original_shape = x.shape - x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( - x, - fps=fps, - padding_mask=padding_mask, - latent_condition=latent_condition, - latent_condition_sigma=latent_condition_sigma, - ) - # logging affline scale information - affline_scale_log_info = {} - - timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) - affline_emb_B_D = timesteps_B_D - affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() - - if scalar_feature is not None: - raise NotImplementedError("Scalar feature is not implemented yet.") - timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) - - if self.additional_timestamp_channels: - additional_cond_B_D = self.prepare_additional_timestamp_embedder( - bs=x.shape[0], - fps=fps, - h=image_size[:, 0], - w=image_size[:, 1], - org_h=image_size[:, 2], - org_w=image_size[:, 3], - ) - - affline_emb_B_D += additional_cond_B_D - affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() - - affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() - affline_emb_B_D = self.affline_norm(affline_emb_B_D) - - # for logging purpose - self.affline_scale_log_info = affline_scale_log_info - self.affline_emb = affline_emb_B_D - self.crossattn_emb = crossattn_emb - self.crossattn_mask = crossattn_mask - - if self.use_cross_attn_mask: - crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] - else: - crossattn_mask = None - - if self.blocks["block0"].x_format == "THWBD": - x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" - ) - crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") - - if crossattn_mask: - crossattn_mask = rearrange(crossattn_mask, "B M -> M B") - - if self.sequence_parallel: - tp_group = parallel_state.get_tensor_model_parallel_group() - # Sequence parallel requires the input tensor to be scattered along the first dimension. - assert self.block_config == "FA-CA-MLP" # Only support this block config for now - T, H, W, B, D = x.shape - # variable name x_T_H_W_B_D is no longer valid. 
x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer - x = x.view(T * H * W, 1, 1, B, D) - assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 - x = scatter_along_first_dim(x, tp_group) - - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( - T * H * W, 1, 1, B, D - ) - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group - ) - - elif self.blocks["block0"].x_format == "BTHWD": - x = x_B_T_H_W_D - else: - raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") - output = { - "x": x, - "affline_emb_B_D": affline_emb_B_D, - "crossattn_emb": crossattn_emb, - "crossattn_mask": crossattn_mask, - "rope_emb_L_1_1_D": rope_emb_L_1_1_D, - "adaln_lora_B_3D": adaln_lora_B_3D, - "original_shape": original_shape, - "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - } - return output - - def forward_blocks_regular( - self, - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D, - adaln_lora_B_3D, - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - feature_indices, - original_shape, - x_ctrl, - return_features_early, - regional_contexts=None, - region_masks=None, - ): - features = [] - for name, block in self.blocks.items(): - assert ( - self.blocks["block0"].x_format == block.x_format - ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" - x = block( - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - adaln_lora_B_3D=adaln_lora_B_3D, - extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - regional_contexts=regional_contexts, - region_masks=region_masks, - ) - - # Extract features - block_idx = int(name.split("block")[-1]) - if block_idx in feature_indices: - B, C, T, H, W = original_shape - H = H // self.patch_spatial - W = W // self.patch_spatial - T = T // self.patch_temporal - if self.sequence_parallel: - x_feat = gather_along_first_dim(x, parallel_state.get_tensor_model_parallel_group()) - x_B_T_H_W_D = rearrange(x_feat, "(T H W) 1 1 B D -> B T H W D", T=T, H=H, W=W) - else: - x_feat = x - if self.blocks["block0"].x_format == "THWBD": - x_B_T_H_W_D = rearrange(x_feat, "T H W B D -> B T H W D", T=T, H=H, W=W) - elif self.blocks["block0"].x_format == "BTHWD": - x_B_T_H_W_D = x_feat - else: - raise ValueError(f"Unknown x_format {self.blocks[-1].x_format}") - - features.append(x_B_T_H_W_D) - - if x_ctrl is not None and name in x_ctrl: - x = x + x_ctrl[name] - # If we have all of the features, we can exit early - if return_features_early and len(features) == len(feature_indices): - return features - - if self.blocks["block0"].x_format == "THWBD": - x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") - elif self.blocks["block0"].x_format == "BTHWD": - x_B_T_H_W_D = x - else: - raise ValueError(f"Unknown x_format {self.blocks[-1].x_format}") - - x_B_D_T_H_W = self.decoder_head( - x_B_T_H_W_D=x_B_T_H_W_D, - emb_B_D=affline_emb_B_D, - crossattn_emb=None, - origin_shape=original_shape, - crossattn_mask=None, - adaln_lora_B_3D=adaln_lora_B_3D, - ) - - if len(feature_indices) == 0: - # no features requested, return only the model output - return x_B_D_T_H_W - else: - # score and features; score, features - return x_B_D_T_H_W, features - - def forward_blocks_memory_save( - self, - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D, - adaln_lora_B_3D, - 
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - feature_indices, - original_shape, - x_ctrl, - return_features_early, - regional_contexts=None, - region_masks=None, - ): - x_before_gate = 0 - x_skip = rearrange(x, "T H W B D -> (T H W) B D") - assert self.blocks["block0"].x_format == "THWBD" - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_per_block_pos_emb = rearrange(extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "T H W B D -> (T H W) B D") - else: - extra_per_block_pos_emb = None - gate_L_B_D = 1.0 - - features = [] - for name, block in self.blocks.items(): - gate_L_B_D, x_before_gate, x_skip = block( - x_before_gate, - x_skip, - gate_L_B_D, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - adaln_lora_B_3D=adaln_lora_B_3D, - extra_per_block_pos_emb=extra_per_block_pos_emb, - regional_contexts=regional_contexts, - region_masks=region_masks, - ) - - # Extract features. - # Convert the block index in the memory save mode to the block index in the regular mode. - block_idx = int(name.split("block")[-1]) - 1 - if block_idx in feature_indices: - B, C, T_before_patchify, H_before_patchify, W_before_patchify = original_shape - H = H_before_patchify // self.patch_spatial - W = W_before_patchify // self.patch_spatial - T = T_before_patchify // self.patch_temporal - if self.sequence_parallel: - x_feat = gather_along_first_dim(x_skip, parallel_state.get_tensor_model_parallel_group()) - x_B_T_H_W_D = rearrange(x_feat, "(T H W) 1 1 B D -> B T H W D", T=T, H=H, W=W) - else: - x_feat = x_skip - x_B_T_H_W_D = rearrange(x_feat, "(T H W) B D -> B T H W D", T=T, H=H, W=W) - - features.append(x_B_T_H_W_D) - - new_name = f"block{block_idx}" - if x_ctrl is not None and new_name in x_ctrl: - x_ctrl_ = x_ctrl[new_name] - x_ctrl_ = rearrange(x_ctrl_, "T H W B D -> (T H W) B D") - x_skip = x_skip + x_ctrl_ - # If we have all of the features, we can exit early - if return_features_early and len(features) == len(feature_indices): - return features - - x_THW_B_D_before_gate = x_before_gate - x_THW_B_D_skip = x_skip - - B, C, T_before_patchify, H_before_patchify, W_before_patchify = original_shape - x_BT_HW_D_before_gate = rearrange( - x_THW_B_D_before_gate, - "(T H W) B D -> (B T) (H W) D", - T=T_before_patchify // self.patch_temporal, - H=H_before_patchify // self.patch_spatial, - W=W_before_patchify // self.patch_spatial, - ) - x_BT_HW_D_skip = rearrange( - x_THW_B_D_skip, - "(T H W) B D -> (B T) (H W) D", - T=T_before_patchify // self.patch_temporal, - H=H_before_patchify // self.patch_spatial, - W=W_before_patchify // self.patch_spatial, - ) - - x_BT_HW_D = self.final_layer.forward_with_memory_save( - x_BT_HW_D_before_gate=x_BT_HW_D_before_gate, - x_BT_HW_D_skip=x_BT_HW_D_skip, - gate_L_B_D=gate_L_B_D, - emb_B_D=affline_emb_B_D, - adaln_lora_B_3D=adaln_lora_B_3D, - ) - - # This is to ensure x_BT_HW_D has the correct shape because - # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). 
- x_BT_HW_D = x_BT_HW_D.view( - B * T_before_patchify // self.patch_temporal, - H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, - -1, - ) - x_B_D_T_H_W = rearrange( - x_BT_HW_D, - "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", - p1=self.patch_spatial, - p2=self.patch_spatial, - H=H_before_patchify // self.patch_spatial, - W=W_before_patchify // self.patch_spatial, - t=self.patch_temporal, - B=B, - ) - if len(feature_indices) == 0: - # no features requested, return only the model output - return x_B_D_T_H_W - else: - # score and features; score, features - return x_B_D_T_H_W, features - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - x_ctrl: Optional[dict] = None, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - feature_indices: Optional[Container[int]] = None, - return_features_early: bool = False, - condition_video_augment_sigma: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: - """ - Args: - x: (B, C, T, H, W) tensor of spatial-temp inputs - timesteps: (B, ) tensor of timesteps - crossattn_emb: (B, N, D) tensor of cross-attention embeddings - crossattn_mask: (B, N) tensor of cross-attention masks - feature_indices: A set of feature indices (a set of integers) decides which blocks - to extract features from. If the set is non-empty, then features will be returned. - By default, feature_indices=None means extract no features. - return_features_early: If true, the forward pass returns the features once the set is complete. - This means the forward pass will not finish completely and no final output is returned. - condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to augment condition input, the lvg model will condition on the condition_video_augment_sigma value; - we need forward_before_blocks pass to the forward_before_blocks function. - """ - if feature_indices is None: - feature_indices = {} - if return_features_early and len(feature_indices) == 0: - # Exit immediately if user requested this. 
- return [] - - inputs = self.forward_before_blocks( - x=x, - timesteps=timesteps, - crossattn_emb=crossattn_emb, - crossattn_mask=crossattn_mask, - fps=fps, - image_size=image_size, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - latent_condition=latent_condition, - latent_condition_sigma=latent_condition_sigma, - condition_video_augment_sigma=condition_video_augment_sigma, - **kwargs, - ) - x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = ( - inputs["x"], - inputs["affline_emb_B_D"], - inputs["crossattn_emb"], - inputs["crossattn_mask"], - inputs["rope_emb_L_1_1_D"], - inputs["adaln_lora_B_3D"], - inputs["original_shape"], - ) - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"] - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - assert ( - x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape - ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" - - if self.use_memory_save: - return self.forward_blocks_memory_save( - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D, - adaln_lora_B_3D, - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - feature_indices, - original_shape, - x_ctrl, - return_features_early, - ) - - return self.forward_blocks_regular( - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D, - adaln_lora_B_3D, - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - feature_indices, - original_shape, - x_ctrl, - return_features_early, - ) - - @property - def fsdp_wrap_block_cls(self): - return DITBuildingBlock - - def enable_context_parallel(self, cp_group: ProcessGroup): - cp_ranks = get_process_group_ranks(cp_group) - cp_size = len(cp_ranks) - # Set these attributes for spliting the data after embedding. - self.cp_group = cp_group - # Set these attributes for computing the loss. - self.cp_size = cp_size - - self.pos_embedder.enable_context_parallel(cp_group) - if self.extra_per_block_abs_pos_emb: - self.extra_pos_embedder.enable_context_parallel(cp_group) - # Loop through the model to set up context parallel. - for block in self.blocks.values(): - for layer in block.blocks: - if layer.block_type in ["mlp", "ff"]: - continue - elif layer.block_type in ["cross_attn", "ca"]: - continue - else: - layer.block.attn.attn_op.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream()) - - log.debug(f"[CP] Enable context parallelism with size {cp_size}") - - def disable_context_parallel(self): - self.cp_group = None - self.cp_size = None - - self.pos_embedder.disable_context_parallel() - if self.extra_per_block_abs_pos_emb: - self.extra_pos_embedder.disable_context_parallel() - - # Loop through the model to disable context parallel. 
- for block in self.blocks.values(): - for layer in block.blocks: - if layer.block_type in ["mlp", "ff"]: - continue - elif layer.block_type in ["cross_attn", "ca"]: - continue - else: - layer.block.attn.attn_op.cp_group = None - layer.block.attn.attn_op.cp_ranks = None - layer.block.attn.attn_op.cp_stream = None - - log.debug("[CP] Disable context parallelism.") - - def enable_sequence_parallel(self): - self._set_sequence_parallel(True) - - def disable_sequence_parallel(self): - self._set_sequence_parallel(False) - - def _set_sequence_parallel(self, status: bool): - self.sequence_parallel = status - self.final_layer.sequence_parallel = status - for block in self.blocks.values(): - for layer in block.blocks: - if layer.block_type in ["full_attn", "fa", "cross_attn", "ca"]: - layer.block.attn.to_q[0].sequence_parallel = status - layer.block.attn.to_k[0].sequence_parallel = status - layer.block.attn.to_v[0].sequence_parallel = status - layer.block.attn.to_out[0].sequence_parallel = status - layer.block.attn.attn_op.sequence_parallel = status - elif layer.block_type in ["mlp", "ff"]: - layer.block.layer1.sequence_parallel = status - layer.block.layer2.sequence_parallel = status - else: - raise ValueError(f"Unknown block type {layer.block_type}") - - @property - def is_context_parallel_enabled(self): - return self.cp_group is not None diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit_ctrl_enc.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit_ctrl_enc.py deleted file mode 100644 index ebb151c009ec05ac1ffa05f36195a05ee33cc7a7..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit_ctrl_enc.py +++ /dev/null @@ -1,403 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -ControlNet Encoder based on GeneralDIT -""" - -from typing import List, Optional, Tuple - -import numpy as np -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import nn -from torchvision import transforms - -from cosmos_transfer1.diffusion.conditioner import DataType -from cosmos_transfer1.diffusion.module.blocks import zero_module -from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp -from cosmos_transfer1.diffusion.training.modules.blocks import PatchEmbed -from cosmos_transfer1.diffusion.training.networks.general_dit_video_conditioned import ( - VideoExtendGeneralDIT as GeneralDIT, -) -from cosmos_transfer1.diffusion.training.tensor_parallel import scatter_along_first_dim - - -class GeneralDITEncoder(GeneralDIT): - """ - ControlNet Encoder based on GeneralDIT. Heavily borrowed from GeneralDIT with minor modifications. 
- """ - - def __init__(self, *args, **kwargs): - hint_channels = kwargs.pop("hint_channels", 16) - self.dropout_ctrl_branch = kwargs.pop("dropout_ctrl_branch", 0.5) - num_control_blocks = kwargs.pop("num_control_blocks", None) - if num_control_blocks is not None: - assert num_control_blocks > 0 and num_control_blocks <= kwargs["num_blocks"] - kwargs["layer_mask"] = [False] * num_control_blocks + [True] * (kwargs["num_blocks"] - num_control_blocks) - self.random_drop_control_blocks = kwargs.pop("random_drop_control_blocks", False) - super().__init__(*args, **kwargs) - num_blocks = self.num_blocks - model_channels = self.model_channels - layer_mask = kwargs.get("layer_mask", None) - layer_mask = [False] * num_blocks if layer_mask is None else layer_mask - self.layer_mask = layer_mask - self.hint_channels = hint_channels - self.build_hint_patch_embed() - hint_nf = [16, 16, 32, 32, 96, 96, 256] - nonlinearity = nn.SiLU() - input_hint_block = [nn.Linear(model_channels, hint_nf[0]), nonlinearity] - for i in range(len(hint_nf) - 1): - input_hint_block += [nn.Linear(hint_nf[i], hint_nf[i + 1]), nonlinearity] - self.input_hint_block = nn.Sequential(*input_hint_block) - # Initialize weights - self.init_weights() - self.zero_blocks = nn.ModuleDict() - for idx in range(num_blocks): - if layer_mask[idx]: - continue - self.zero_blocks[f"block{idx}"] = zero_module(nn.Linear(model_channels, model_channels)) - self.input_hint_block.append(zero_module(nn.Linear(hint_nf[-1], model_channels))) - - def _set_sequence_parallel(self, status: bool): - self.zero_blocks.sequence_parallel = status - self.input_hint_block.sequence_parallel = status - super()._set_sequence_parallel(status) - - def build_hint_patch_embed(self): - concat_padding_mask, in_channels, patch_spatial, patch_temporal, model_channels = ( - self.concat_padding_mask, - self.hint_channels, - self.patch_spatial, - self.patch_temporal, - self.model_channels, - ) - in_channels = in_channels + 1 if concat_padding_mask else in_channels - self.x_embedder2 = PatchEmbed( - spatial_patch_size=patch_spatial, - temporal_patch_size=patch_temporal, - in_channels=in_channels, - out_channels=model_channels, - bias=False, - keep_spatio=True, - legacy_patch_emb=self.legacy_patch_emb, - ) - - if self.legacy_patch_emb: - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) - w = self.x_embedder2.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - - def prepare_hint_embedded_sequence( - self, x_B_C_T_H_W: torch.Tensor, fps: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if self.concat_padding_mask: - padding_mask = transforms.functional.resize( - padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST - ) - x_B_C_T_H_W = torch.cat( - [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[2], 1, 1)], - dim=1, - ) - - x_B_T_H_W_D = self.x_embedder2(x_B_C_T_H_W) - - if "rope" in self.pos_emb_cls.lower(): - return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps) - - if "fps_aware" in self.pos_emb_cls: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] - else: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] - return x_B_T_H_W_D, None - - def encode_hint( - self, - hint: torch.Tensor, - fps: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - ) 
-> torch.Tensor: - assert hint.size(1) <= self.hint_channels, f"Expected hint channels <= {self.hint_channels}, got {hint.size(1)}" - if hint.size(1) < self.hint_channels: - padding_shape = list(hint.shape) - padding_shape[1] = self.hint_channels - hint.size(1) - hint = torch.cat([hint, torch.zeros(*padding_shape, dtype=hint.dtype, device=hint.device)], dim=1) - assert isinstance( - data_type, DataType - ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." - - hint_B_T_H_W_D, _ = self.prepare_hint_embedded_sequence(hint, fps=fps, padding_mask=padding_mask) - - hint = rearrange(hint_B_T_H_W_D, "B T H W D -> T H W B D") - if self.sequence_parallel: - tp_group = parallel_state.get_tensor_model_parallel_group() - T, H, W, B, D = hint.shape - hint = hint.view(T * H * W, 1, 1, B, -1) - hint = scatter_along_first_dim(hint, tp_group) - - guided_hint = self.input_hint_block(hint) - return guided_hint - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - hint_key: Optional[str] = None, - base_model: Optional[nn.Module] = None, - control_weight: Optional[float] = 1.0, - num_layers_to_use: Optional[int] = -1, - condition_video_input_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: - """ - Args: - x: (B, C, T, H, W) tensor of spatial-temp inputs - timesteps: (B, ) tensor of timesteps - crossattn_emb: (B, N, D) tensor of cross-attention embeddings - crossattn_mask: (B, N) tensor of cross-attention masks - """ - # record the input as they are replaced in this forward - x_input = x - crossattn_emb_input = crossattn_emb - crossattn_mask_input = crossattn_mask - condition_video_input_mask_input = condition_video_input_mask - - hint = kwargs.pop(hint_key) - if hint is None: - print("using none hint") - return base_model.net.forward( - x=x_input, - timesteps=timesteps, - crossattn_emb=crossattn_emb_input, - crossattn_mask=crossattn_mask_input, - fps=fps, - image_size=image_size, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - condition_video_input_mask=condition_video_input_mask_input, - **kwargs, - ) - if hasattr(self, "hint_encoders"): # for multicontrol - guided_hints = [] - for i in range(hint.shape[1]): - self.input_hint_block = self.hint_encoders[i].input_hint_block - self.pos_embedder = self.hint_encoders[i].pos_embedder - self.x_embedder2 = self.hint_encoders[i].x_embedder2 - guided_hints += [self.encode_hint(hint[:, i], fps=fps, padding_mask=padding_mask, data_type=data_type)] - else: - guided_hints = self.encode_hint(hint, fps=fps, padding_mask=padding_mask, data_type=data_type) - guided_hints = torch.chunk(guided_hints, hint.shape[0] // x.shape[0], dim=3) - # Only support multi-control at inference time - assert len(guided_hints) == 1 or not torch.is_grad_enabled() - - assert isinstance( - data_type, DataType - ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
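# Hedged sketch of the "zero module" pattern used for `zero_blocks` and the last
# layer of `input_hint_block` above. The real `zero_module` lives in
# cosmos_transfer1.diffusion.module.blocks and is not shown in this diff; the
# assumed minimal equivalent below zeroes all parameters so the control branch
# contributes nothing to the base model's activations at initialization.
import torch
from torch import nn

def zero_module(module: nn.Module) -> nn.Module:
    for p in module.parameters():
        nn.init.zeros_(p)                  # weights and biases start at exactly zero
    return module

proj = zero_module(nn.Linear(256, 256))
x = torch.randn(4, 256)
assert torch.allclose(proj(x), torch.zeros(4, 256))   # no contribution at init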
- - B, C, T, H, W = x.shape - if data_type == DataType.VIDEO: - if condition_video_input_mask is not None: - if self.cp_group is not None: - condition_video_input_mask = split_inputs_cp( - condition_video_input_mask, seq_dim=2, cp_group=self.cp_group - ) - input_list = [x, condition_video_input_mask] - x = torch.cat(input_list, dim=1) - - elif data_type == DataType.IMAGE: - # For image, we dont have condition_video_input_mask, or condition_video_pose - # We need to add the extra channel for video condition mask - padding_channels = self.in_channels - x.shape[1] - x = torch.cat([x, torch.zeros((B, padding_channels, T, H, W), dtype=x.dtype, device=x.device)], dim=1) - else: - assert x.shape[1] == self.in_channels, f"Expected {self.in_channels} channels, got {x.shape[1]}" - - self.crossattn_emb = crossattn_emb - self.crossattn_mask = crossattn_mask - - if self.use_cross_attn_mask: - crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] - else: - crossattn_mask = None - - crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") - if crossattn_mask: - crossattn_mask = rearrange(crossattn_mask, "B M -> M B") - - outs = {} - - # (Experimental, not used in the released model) if also training base model, sometimes drop the - # controlnet branch to only train base branch. This is to prevent the network become dependent on - # controlnet branch and make control weight useless. - is_training = torch.is_grad_enabled() - is_training_base_model = any(p.requires_grad for p in base_model.parameters()) - if is_training and is_training_base_model: - coin_flip = torch.rand(B).to(x.device) > self.dropout_ctrl_branch # prob for only training base model - if self.blocks["block0"].x_format == "THWBD": - coin_flip = coin_flip[None, None, None, :, None] - elif self.blocks["block0"].x_format == "BTHWD": - coin_flip = coin_flip[:, None, None, None, None] - else: - coin_flip = 1 - - num_control_blocks = self.layer_mask.index(True) - num_layers_to_use = num_control_blocks - control_gate_per_layer = [i < num_layers_to_use for i in range(num_control_blocks)] - - if isinstance(control_weight, torch.Tensor): - if control_weight.ndim == 0: # Single scalar tensor - control_weight = [float(control_weight)] * len(guided_hints) - elif control_weight.ndim == 1: # List of scalar weights - control_weight = [float(w) for w in control_weight] - else: # Spatial-temporal weight maps - control_weight = [w for w in control_weight] # Keep as tensor - else: - control_weight = [control_weight] * len(guided_hints) - - x_before_blocks = x.clone() - for i, guided_hint in enumerate(guided_hints): - x = x_before_blocks - if hasattr(self, "hint_encoders"): # for multicontrol - blocks = self.hint_encoders[i].blocks - zero_blocks = self.hint_encoders[i].zero_blocks - t_embedder = self.hint_encoders[i].t_embedder - affline_norm = self.hint_encoders[i].affline_norm - self.x_embedder = self.hint_encoders[i].x_embedder - self.extra_pos_embedder = self.hint_encoders[i].extra_pos_embedder - else: - blocks = self.blocks - zero_blocks = self.zero_blocks - t_embedder = self.t_embedder - affline_norm = self.affline_norm - - x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( - x, fps=fps, padding_mask=padding_mask - ) - # logging affline scale information - affline_scale_log_info = {} - - timesteps_B_D, adaln_lora_B_3D = t_embedder(timesteps.flatten()) - affline_emb_B_D = timesteps_B_D - affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() - - if 
scalar_feature is not None: - raise NotImplementedError("Scalar feature is not implemented yet.") - - if self.additional_timestamp_channels: - additional_cond_B_D = self.prepare_additional_timestamp_embedder( - bs=x.shape[0], - fps=fps, - h=image_size[:, 0], - w=image_size[:, 1], - org_h=image_size[:, 2], - org_w=image_size[:, 3], - ) - - affline_emb_B_D += additional_cond_B_D - affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() - - affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() - affline_emb_B_D = affline_norm(affline_emb_B_D) - - # for logging purpose - self.affline_scale_log_info = affline_scale_log_info - self.affline_emb = affline_emb_B_D - - x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" - ) - - if self.sequence_parallel: - tp_group = parallel_state.get_tensor_model_parallel_group() - # Sequence parallel requires the input tensor to be scattered along the first dimension. - assert self.block_config == "FA-CA-MLP" # Only support this block config for now - T, H, W, B, D = x.shape - # variable name x_T_H_W_B_D is no longer valid. x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer - x = x.view(T * H * W, 1, 1, B, D) - assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 - x = scatter_along_first_dim(x, tp_group) - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( - T * H * W, 1, 1, B, D - ) - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group - ) - - for idx, (name, block) in enumerate(blocks.items()): - assert ( - blocks["block0"].x_format == block.x_format - ), f"First block has x_format {blocks[0].x_format}, got {block.x_format}" - x = block( - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - adaln_lora_B_3D=adaln_lora_B_3D, - extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - ) - if guided_hint is not None: - x = x + guided_hint - guided_hint = None - - gate = control_gate_per_layer[idx] - if isinstance(control_weight[i], (float, int)) or control_weight[i].ndim < 2: - hint_val = zero_blocks[name](x) * control_weight[i] * coin_flip * gate - else: # Spatial-temporal weights [num_controls, B, 1, T, H, W] - control_feat = zero_blocks[name](x) - # Get current feature dimensions - weight_map = control_weight[i] # [B, 1, T, H, W] - # Reshape to match THWBD format - weight_map = weight_map.permute(2, 3, 4, 0, 1) # [T, H, W, B, 1] - weight_map = weight_map.view(T * H * W, 1, 1, B, 1) - - if self.sequence_parallel: - weight_map = scatter_along_first_dim(weight_map, tp_group) - - hint_val = control_feat * weight_map * coin_flip * gate - - if name not in outs: - outs[name] = hint_val - else: - outs[name] += hint_val - - output = base_model.net.forward( - x=x_input, - timesteps=timesteps, - crossattn_emb=crossattn_emb_input, - crossattn_mask=crossattn_mask_input, - fps=fps, - image_size=image_size, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - x_ctrl=outs, - condition_video_input_mask=condition_video_input_mask_input, - **kwargs, - ) - return output diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit_ctrl_enc_multicamera.py 
b/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit_ctrl_enc_multicamera.py deleted file mode 100644 index 6f606b3073864a291108c5ff463229ce2b413d19..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit_ctrl_enc_multicamera.py +++ /dev/null @@ -1,562 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -ControlNet Encoder based on GeneralDIT -""" - -from typing import List, Optional, Tuple - -import numpy as np -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import nn -from torchvision import transforms - -from cosmos_transfer1.diffusion.conditioner import DataType -from cosmos_transfer1.diffusion.module.blocks import zero_module -from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp -from cosmos_transfer1.diffusion.training.modules.blocks import PatchEmbed -from cosmos_transfer1.diffusion.training.networks.general_dit_multi_camera import MultiCameraGeneralDIT -from cosmos_transfer1.diffusion.training.tensor_parallel import scatter_along_first_dim -from cosmos_transfer1.utils import log - - -class GeneralDITMulticamEncoder(MultiCameraGeneralDIT): - """ - ControlNet Encoder based on GeneralDIT. Heavily borrowed from GeneralDIT with minor modifications. 
- """ - - def __init__(self, *args, in_channels, is_extend_model=False, **kwargs): - self.is_extend_model = is_extend_model - if is_extend_model: - new_input_channels = in_channels + 1 - log.info(f"Updating input channels to {new_input_channels} to accomodate cond_mask") - else: - new_input_channels = in_channels - - if kwargs.get("add_augment_sigma_embedding", None) is not None: - self.add_augment_sigma_embedding = kwargs.pop("add_augment_sigma_embedding") - else: - self.add_augment_sigma_embedding = False - hint_channels = kwargs.pop("hint_channels", 16) - self.dropout_ctrl_branch = kwargs.pop("dropout_ctrl_branch", 0.5) - num_control_blocks = kwargs.pop("num_control_blocks", None) - if num_control_blocks is not None: - assert num_control_blocks > 0 and num_control_blocks <= kwargs["num_blocks"] - kwargs["layer_mask"] = [False] * num_control_blocks + [True] * (kwargs["num_blocks"] - num_control_blocks) - self.random_drop_control_blocks = kwargs.pop("random_drop_control_blocks", False) - super().__init__(*args, in_channels=new_input_channels, **kwargs) - num_blocks = self.num_blocks - model_channels = self.model_channels - layer_mask = kwargs.get("layer_mask", None) - layer_mask = [False] * num_blocks if layer_mask is None else layer_mask - self.layer_mask = layer_mask - self.hint_channels = hint_channels - self.build_hint_patch_embed() - hint_nf = [16, 16, 32, 32, 96, 96, 256] - nonlinearity = nn.SiLU() - input_hint_block = [nn.Linear(model_channels, hint_nf[0]), nonlinearity] - for i in range(len(hint_nf) - 1): - input_hint_block += [nn.Linear(hint_nf[i], hint_nf[i + 1]), nonlinearity] - self.input_hint_block = nn.Sequential(*input_hint_block) - # Initialize weights - self.init_weights() - self.zero_blocks = nn.ModuleDict() - for idx in range(num_blocks): - if layer_mask[idx]: - continue - self.zero_blocks[f"block{idx}"] = zero_module(nn.Linear(model_channels, model_channels)) - self.input_hint_block.append(zero_module(nn.Linear(hint_nf[-1], model_channels))) - - def _set_sequence_parallel(self, status: bool): - self.zero_blocks.sequence_parallel = status - self.input_hint_block.sequence_parallel = status - super()._set_sequence_parallel(status) - - def build_hint_patch_embed(self): - concat_padding_mask, in_channels, patch_spatial, patch_temporal, model_channels = ( - self.concat_padding_mask, - self.hint_channels, - self.patch_spatial, - self.patch_temporal, - self.model_channels, - ) - in_channels = in_channels + 1 if concat_padding_mask else in_channels - self.x_embedder2 = PatchEmbed( - spatial_patch_size=patch_spatial, - temporal_patch_size=patch_temporal, - in_channels=in_channels, - out_channels=model_channels, - bias=False, - keep_spatio=True, - legacy_patch_emb=self.legacy_patch_emb, - ) - - if self.legacy_patch_emb: - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) - w = self.x_embedder2.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - - def prepare_hint_embedded_sequence( - self, x_B_C_T_H_W: torch.Tensor, fps: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if self.concat_padding_mask: - padding_mask = transforms.functional.resize( - padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST - ) - x_B_C_T_H_W = torch.cat( - [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[2], 1, 1)], - dim=1, - ) - - x_B_T_H_W_D = self.x_embedder2(x_B_C_T_H_W) - - if "rope" in 
self.pos_emb_cls.lower(): - return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps) - - if "fps_aware" in self.pos_emb_cls: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] - else: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] - return x_B_T_H_W_D, None - - def encode_hint( - self, - hint: torch.Tensor, - fps: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - ) -> torch.Tensor: - assert hint.size(1) <= self.hint_channels, f"Expected hint channels <= {self.hint_channels}, got {hint.size(1)}" - if hint.size(1) < self.hint_channels: - padding_shape = list(hint.shape) - padding_shape[1] = self.hint_channels - hint.size(1) - hint = torch.cat([hint, torch.zeros(*padding_shape, dtype=hint.dtype, device=hint.device)], dim=1) - assert isinstance( - data_type, DataType - ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." - - hint_B_T_H_W_D, _ = self.prepare_hint_embedded_sequence(hint, fps=fps, padding_mask=padding_mask) - - if self.blocks["block0"].x_format == "THWBD": - hint = rearrange(hint_B_T_H_W_D, "B T H W D -> T H W B D") - if self.sequence_parallel: - tp_group = parallel_state.get_tensor_model_parallel_group() - T, H, W, B, D = hint.shape - hint = hint.view(T * H * W, 1, 1, B, -1) - hint = scatter_along_first_dim(hint, tp_group) - elif self.blocks["block0"].x_format == "BTHWD": - hint = hint_B_T_H_W_D - else: - raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") - - guided_hint = self.input_hint_block(hint) - return guided_hint - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - hint_key: Optional[str] = None, - base_model: Optional[nn.Module] = None, - control_weight: Optional[float] = 1.0, - num_layers_to_use: Optional[int] = -1, - condition_video_input_mask: Optional[torch.Tensor] = None, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - view_indices_B_T: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor | List[torch.Tensor] | Tuple[torch.Tensor, List[torch.Tensor]]: - """ - Args: - x: (B, C, T, H, W) tensor of spatial-temp inputs - timesteps: (B, ) tensor of timesteps - crossattn_emb: (B, N, D) tensor of cross-attention embeddings - crossattn_mask: (B, N) tensor of cross-attention masks - """ - # record the input as they are replaced in this forward - x_input = x - frame_repeat = kwargs.get("frame_repeat", None) - original_shape = x.shape - crossattn_emb_input = crossattn_emb - crossattn_mask_input = crossattn_mask - condition_video_input_mask_input = condition_video_input_mask - hint = kwargs.pop(hint_key) - if hint is None: - log.info("using none hint") - return base_model.net.forward( - x=x_input, - timesteps=timesteps, - crossattn_emb=crossattn_emb_input, - crossattn_mask=crossattn_mask_input, - fps=fps, - image_size=image_size, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - condition_video_input_mask=condition_video_input_mask_input, - latent_condition=latent_condition, - latent_condition_sigma=latent_condition_sigma, - 
view_indices_B_T=view_indices_B_T, - **kwargs, - ) - - if hasattr(self, "hint_encoders"): # for multicontrol - guided_hints = [] - for i in range(hint.shape[1]): - self.input_hint_block = self.hint_encoders[i].input_hint_block - self.pos_embedder = self.hint_encoders[i].pos_embedder - self.x_embedder2 = self.hint_encoders[i].x_embedder2 - guided_hints += [self.encode_hint(hint[:, i], fps=fps, padding_mask=padding_mask, data_type=data_type)] - else: - guided_hints = self.encode_hint(hint, fps=fps, padding_mask=padding_mask, data_type=data_type) - guided_hints = torch.chunk(guided_hints, hint.shape[0] // x.shape[0], dim=3) - # Only support multi-control at inference time - assert len(guided_hints) == 1 or not torch.is_grad_enabled() - - assert isinstance( - data_type, DataType - ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." - - B, C, T, H, W = x.shape - if data_type == DataType.VIDEO: - if condition_video_input_mask is not None: - if self.cp_group is not None: - condition_video_input_mask = rearrange( - condition_video_input_mask, "B C (V T) H W -> B C V T H W", V=self.n_views - ) - condition_video_input_mask = split_inputs_cp( - condition_video_input_mask, seq_dim=3, cp_group=self.cp_group - ) - condition_video_input_mask = rearrange( - condition_video_input_mask, "B C V T H W -> B C (V T) H W", V=self.n_views - ) - input_list = [x, condition_video_input_mask] - x = torch.cat( - input_list, - dim=1, - ) - - elif data_type == DataType.IMAGE: - # For image, we dont have condition_video_input_mask, or condition_video_pose - # We need to add the extra channel for video condition mask - padding_channels = self.in_channels - x.shape[1] - x = torch.cat([x, torch.zeros((B, padding_channels, T, H, W), dtype=x.dtype, device=x.device)], dim=1) - else: - assert x.shape[1] == self.in_channels, f"Expected {self.in_channels} channels, got {x.shape[1]}" - - self.crossattn_emb = crossattn_emb - self.crossattn_mask = crossattn_mask - - if self.use_cross_attn_mask: - crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] - else: - crossattn_mask = None - - if self.blocks["block0"].x_format == "THWBD": - crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") - if crossattn_mask: - crossattn_mask = rearrange(crossattn_mask, "B M -> M B") - - outs = {} - - # If also training base model, sometimes drop the controlnet branch to only train base branch. - # This is to prevent the network become dependent on controlnet branch and make control weight useless. - is_training = torch.is_grad_enabled() - is_training_base_model = any(p.requires_grad for p in base_model.parameters()) - if is_training and is_training_base_model: - coin_flip = torch.rand(B).to(x.device) > self.dropout_ctrl_branch # prob for only training base model - if self.blocks["block0"].x_format == "THWBD": - coin_flip = coin_flip[None, None, None, :, None] - elif self.blocks["block0"].x_format == "BTHWD": - coin_flip = coin_flip[:, None, None, None, None] - else: - coin_flip = 1 - - num_control_blocks = self.layer_mask.index(True) - if self.random_drop_control_blocks: - if is_training: # Use a random number of layers during training. - num_layers_to_use = np.random.randint(num_control_blocks) + 1 - elif num_layers_to_use == -1: # Evaluate using all the layers. - num_layers_to_use = num_control_blocks - else: # Use the specified number of layers during inference. - pass - else: # Use all of the layers. 
- num_layers_to_use = num_control_blocks - control_gate_per_layer = [i < num_layers_to_use for i in range(num_control_blocks)] - - if isinstance(control_weight, torch.Tensor): - if control_weight.ndim == 0: # Single scalar tensor - control_weight = [float(control_weight)] * len(guided_hints) - elif control_weight.ndim == 1: # List of scalar weights - control_weight = [float(w) for w in control_weight] - else: # Spatial-temporal weight maps - control_weight = [w for w in control_weight] # Keep as tensor - else: - control_weight = [control_weight] * len(guided_hints) - - # max_norm = {} - x_before_blocks = x.clone() - for i, guided_hint in enumerate(guided_hints): - x = x_before_blocks - if hasattr(self, "hint_encoders"): # for multicontrol - blocks = self.hint_encoders[i].blocks - zero_blocks = self.hint_encoders[i].zero_blocks - t_embedder = self.hint_encoders[i].t_embedder - affline_norm = self.hint_encoders[i].affline_norm - self.x_embedder = self.hint_encoders[i].x_embedder - self.extra_pos_embedder = self.hint_encoders[i].extra_pos_embedder - else: - blocks = self.blocks - zero_blocks = self.zero_blocks - t_embedder = self.t_embedder - affline_norm = self.affline_norm - - x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( - x, - fps=fps, - padding_mask=padding_mask, - latent_condition=latent_condition, - latent_condition_sigma=latent_condition_sigma, - frame_repeat=frame_repeat, - view_indices_B_T=view_indices_B_T, - ) - # logging affline scale information - affline_scale_log_info = {} - - timesteps_B_D, adaln_lora_B_3D = t_embedder(timesteps.flatten()) - affline_emb_B_D = timesteps_B_D - affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() - - if scalar_feature is not None: - raise NotImplementedError("Scalar feature is not implemented yet.") - - if self.additional_timestamp_channels: - additional_cond_B_D = self.prepare_additional_timestamp_embedder( - bs=x.shape[0], - fps=fps, - h=image_size[:, 0], - w=image_size[:, 1], - org_h=image_size[:, 2], - org_w=image_size[:, 3], - ) - - affline_emb_B_D += additional_cond_B_D - affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() - - affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() - affline_emb_B_D = affline_norm(affline_emb_B_D) - - # for logging purpose - self.affline_scale_log_info = affline_scale_log_info - self.affline_emb = affline_emb_B_D - - x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" - ) - - if self.sequence_parallel: - tp_group = parallel_state.get_tensor_model_parallel_group() - # Sequence parallel requires the input tensor to be scattered along the first dimension. - assert self.block_config == "FA-CA-MLP" # Only support this block config for now - T, H, W, B, D = x.shape - # variable name x_T_H_W_B_D is no longer valid. 
x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer - x = x.view(T * H * W, 1, 1, B, D) - assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 - x = scatter_along_first_dim(x, tp_group) - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( - T * H * W, 1, 1, B, D - ) - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group - ) - - for idx, (name, block) in enumerate(blocks.items()): - assert ( - blocks["block0"].x_format == block.x_format - ), f"First block has x_format {blocks[0].x_format}, got {block.x_format}" - x = block( - x, - affline_emb_B_D, - crossattn_emb, - crossattn_mask, - rope_emb_L_1_1_D=rope_emb_L_1_1_D, - adaln_lora_B_3D=adaln_lora_B_3D, - extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - ) - if guided_hint is not None: - x = x + guided_hint - guided_hint = None - - gate = control_gate_per_layer[idx] - if isinstance(control_weight[i], (float, int)) or control_weight[i].ndim < 2: - hint_val = zero_blocks[name](x) * control_weight[i] * coin_flip * gate - else: # Spatial-temporal weights [num_controls, B, 1, T, H, W] - control_feat = zero_blocks[name](x) - - # Get current feature dimensions - if self.blocks["block0"].x_format == "THWBD": - weight_map = control_weight[i] # [B, 1, T, H, W] - - if weight_map.shape[2:5] != (T, H, W): - assert weight_map.shape[2] == 8 * (T - 1) + 1 - weight_map_i = [ - torch.nn.functional.interpolate( - weight_map[:, :, :1, :, :], - size=(1, H, W), - mode="trilinear", - align_corners=False, - ) - ] - for wi in range(1, weight_map.shape[2], 8): - weight_map_i += [ - torch.nn.functional.interpolate( - weight_map[:, :, wi : wi + 8], - size=(1, H, W), - mode="trilinear", - align_corners=False, - ) - ] - weight_map = torch.cat(weight_map_i, dim=2) - - # Reshape to match THWBD format - weight_map = weight_map.permute(2, 3, 4, 0, 1) # [T, H, W, B, 1] - weight_map = weight_map.view(T * H * W, 1, 1, B, 1) - if self.sequence_parallel: - weight_map = scatter_along_first_dim(weight_map, tp_group) - - else: # BTHWD format - raise NotImplementedError("BTHWD format for weight map is not implemented yet.") - hint_val = control_feat * weight_map * coin_flip * gate - - if name not in outs: - outs[name] = hint_val - else: - outs[name] += hint_val - - output = base_model.net.forward( - x=x_input, - timesteps=timesteps, - crossattn_emb=crossattn_emb_input, - crossattn_mask=crossattn_mask_input, - fps=fps, - image_size=image_size, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - x_ctrl=outs, - condition_video_input_mask=condition_video_input_mask_input, - latent_condition=latent_condition, - latent_condition_sigma=latent_condition_sigma, - view_indices_B_T=view_indices_B_T, - **kwargs, - ) - return output - - -class VideoExtendGeneralDITMulticamEncoder(GeneralDITMulticamEncoder): - def __init__(self, *args, in_channels, add_augment_sigma_embedding=False, **kwargs): - self.add_augment_sigma_embedding = add_augment_sigma_embedding - # extra channel for video condition mask - super().__init__(*args, in_channels=in_channels, **kwargs) - log.info(f"VideoExtendGeneralDIT in_channels: {in_channels}") - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - 
padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - hint_key: Optional[str] = None, - base_model: Optional[nn.Module] = None, - control_weight: Optional[float] = 1.0, - num_layers_to_use: Optional[int] = -1, - video_cond_bool: Optional[torch.Tensor] = None, - condition_video_indicator: Optional[torch.Tensor] = None, - condition_video_input_mask: Optional[torch.Tensor] = None, - condition_video_augment_sigma: Optional[torch.Tensor] = None, - condition_video_pose: Optional[torch.Tensor] = None, - view_indices_B_T: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """Args: - - condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation - condition_video_pose: (B, 1, T, H, W) tensor of pose condition - """ - B, C, T, H, W = x.shape - - if data_type == DataType.VIDEO: - assert ( - condition_video_input_mask is not None - ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" - if self.cp_group is not None: - condition_video_input_mask = rearrange( - condition_video_input_mask, "B C (V T) H W -> B C V T H W", V=self.n_views - ) - condition_video_input_mask = split_inputs_cp( - condition_video_input_mask, seq_dim=3, cp_group=self.cp_group - ) - condition_video_input_mask = rearrange( - condition_video_input_mask, "B C V T H W -> B C (V T) H W", V=self.n_views - ) - input_list = [x, condition_video_input_mask] - if condition_video_pose is not None: - if condition_video_pose.shape[2] > T: - log.warning( - f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" - ) - condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() - input_list.append(condition_video_pose) - x = torch.cat( - input_list, - dim=1, - ) - - return super().forward( # general_dit.GeneralDIT.forward() - x=x, - timesteps=timesteps, - crossattn_emb=crossattn_emb, - crossattn_mask=crossattn_mask, - fps=fps, - image_size=image_size, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - hint_key=hint_key, - base_model=base_model, - control_weight=control_weight, - num_layers_to_use=num_layers_to_use, - condition_video_augment_sigma=condition_video_augment_sigma, - view_indices_B_T=view_indices_B_T, - **kwargs, - ) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit_multi_camera.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit_multi_camera.py deleted file mode 100644 index 37d0614404be7fce9649edf5669da6c4db36ed79..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit_multi_camera.py +++ /dev/null @@ -1,491 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Tuple - -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import nn -from torchvision import transforms - -from cosmos_transfer1.diffusion.conditioner import DataType -from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp -from cosmos_transfer1.diffusion.module.position_embedding import ( - LearnableEmb3D, - LearnableEmb3D_FPS_Aware, - LearnablePosEmbAxis, - MultiCameraSinCosPosEmbAxis, - MultiCameraVideoRopePosition3DEmb, - SinCosPosEmb, - SinCosPosEmb_FPS_Aware, - SinCosPosEmbAxis, - VideoRopePosition3DEmb, - VideoRopePositionEmb, -) -from cosmos_transfer1.diffusion.training.modules.blocks import ( - GeneralDITTransformerBlock, - PatchEmbed, - SDXLTimestepEmbedding, - SDXLTimesteps, -) -from cosmos_transfer1.diffusion.training.networks.general_dit import GeneralDIT -from cosmos_transfer1.diffusion.training.tensor_parallel import scatter_along_first_dim -from cosmos_transfer1.utils import log - - -class MultiCameraGeneralDIT(GeneralDIT): - def __init__( - self, - *args, - n_views: int = 3, - n_views_emb: int = -1, - view_condition_dim: int = 3, - traj_condition_dim: int = 0, - concat_view_embedding: bool = True, - concat_traj_embedding: bool = False, - add_repeat_frame_embedding: bool = False, - **kwargs, - ): - if kwargs.get("add_augment_sigma_embedding", None) is not None: - kwargs.pop("add_augment_sigma_embedding") - self.n_views = n_views - if n_views_emb < 0: - self.n_views_emb = n_views - else: - self.n_views_emb = n_views_emb - self.view_condition_dim = view_condition_dim - self.concat_view_embedding = concat_view_embedding - self.traj_condition_dim = traj_condition_dim - self.concat_traj_embedding = concat_traj_embedding - self.add_repeat_frame_embedding = add_repeat_frame_embedding - - super().__init__(*args, **kwargs) - # reinit self.blocks - del self.blocks - self.blocks = nn.ModuleDict() - - layer_mask = [False] * self.num_blocks if kwargs["layer_mask"] is None else kwargs["layer_mask"] - assert ( - len(layer_mask) == self.num_blocks - ), f"Layer mask length {len(layer_mask)} does not match num_blocks { self.num_blocks}" - for idx in range(self.num_blocks): - if layer_mask[idx]: - continue - self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( - x_dim=self.model_channels, - context_dim=kwargs["crossattn_emb_channels"], - num_heads=self.num_heads, - block_config=self.block_config, - window_sizes=( - kwargs["window_sizes"] if idx in kwargs["window_block_indexes"] else [] - ), # There will be bug if using "WA-CA-MLP" - mlp_ratio=kwargs["mlp_ratio"], - spatial_attn_win_size=kwargs["spatial_attn_win_size"], - temporal_attn_win_size=kwargs["temporal_attn_win_size"], - x_format=self.block_x_format, - use_adaln_lora=self.use_adaln_lora, - adaln_lora_dim=self.adaln_lora_dim, - n_views=self.n_views, - ) - self.view_embeddings = nn.Embedding(self.n_views_emb, view_condition_dim) # Learnable embedding layer - - if self.concat_traj_embedding: - self.traj_embeddings = nn.Linear(192, self.traj_condition_dim) # Learnable embedding layer - if self.add_repeat_frame_embedding: - self.repeat_frame_embedding = nn.Linear(1, view_condition_dim) # Learnable embedding layer - - self.init_weights() - - def build_patch_embed(self): - ( - concat_padding_mask, - in_channels, - patch_spatial, - patch_temporal, - model_channels, - view_condition_dim, - traj_condition_dim, - ) = ( - 
self.concat_padding_mask, - self.in_channels, - self.patch_spatial, - self.patch_temporal, - self.model_channels, - self.view_condition_dim, - self.traj_condition_dim, - ) - if self.concat_view_embedding: - in_channels = in_channels + view_condition_dim if view_condition_dim > 0 else in_channels - - if self.concat_traj_embedding: - in_channels = in_channels + traj_condition_dim if traj_condition_dim > 0 else in_channels - - in_channels = in_channels + 1 if concat_padding_mask else in_channels - - self.x_embedder = PatchEmbed( - spatial_patch_size=patch_spatial, - temporal_patch_size=patch_temporal, - in_channels=in_channels, - out_channels=model_channels, - bias=False, - keep_spatio=True, - legacy_patch_emb=self.legacy_patch_emb, - ) - - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) - if self.legacy_patch_emb: - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - - def build_pos_embed(self): - if self.pos_emb_cls == "rope3d": - cls_type = MultiCameraVideoRopePosition3DEmb - else: - raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") - - log.critical(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") - kwargs = dict( - model_channels=self.model_channels, - len_h=self.max_img_h // self.patch_spatial, - len_w=self.max_img_w // self.patch_spatial, - len_t=self.max_frames // self.patch_temporal, - max_fps=self.max_fps, - min_fps=self.min_fps, - is_learnable=self.pos_emb_learnable, - interpolation=self.pos_emb_interpolation, - head_dim=self.model_channels // self.num_heads, - h_extrapolation_ratio=self.rope_h_extrapolation_ratio, - w_extrapolation_ratio=self.rope_w_extrapolation_ratio, - t_extrapolation_ratio=self.rope_t_extrapolation_ratio, - n_views=self.n_views, - ) - self.pos_embedder = cls_type( - **kwargs, - ) - - if self.extra_per_block_abs_pos_emb: - kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio - kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio - kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio - self.extra_pos_embedder = MultiCameraSinCosPosEmbAxis( - **kwargs, - ) - - def forward_before_blocks( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - view_indices_B_T: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """ - Args: - x: (B, C, T, H, W) tensor of spatial-temp inputs - timesteps: (B, ) tensor of timesteps - crossattn_emb: (B, N, D) tensor of cross-attention embeddings - crossattn_mask: (B, N) tensor of cross-attention masks - """ - trajectory = kwargs.get("trajectory", None) - frame_repeat = kwargs.get("frame_repeat", None) - - del kwargs - assert isinstance( - data_type, DataType - ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." 
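# A minimal, self-contained sketch of the multi-view latent layout that this class
# relies on (and that prepare_embedded_sequence() unpacks below): V camera views are
# packed along the time axis, so latents have shape (B, C, V*T, H, W). The shapes used
# here are arbitrary example values, not taken from any config.
import torch
from einops import rearrange

B, C, V, T, H, W = 1, 16, 3, 8, 32, 32
x = torch.randn(B, C, V * T, H, W)                             # views packed into the time axis
x_views = rearrange(x, "B C (V T) H W -> B C V T H W", V=V)    # split into per-view tensors
x_packed = rearrange(x_views, "B C V T H W -> B C (V T) H W")  # repack for the DiT blocks
assert torch.equal(x, x_packed)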
- original_shape = x.shape - x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( - x, - fps=fps, - padding_mask=padding_mask, - latent_condition=latent_condition, - latent_condition_sigma=latent_condition_sigma, - trajectory=trajectory, - frame_repeat=frame_repeat, - view_indices_B_T=view_indices_B_T, - ) - # logging affline scale information - affline_scale_log_info = {} - - timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) - affline_emb_B_D = timesteps_B_D - affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() - - if self.additional_timestamp_channels: - additional_cond_B_D = self.prepare_additional_timestamp_embedder( - bs=x.shape[0], - fps=fps, - h=image_size[:, 0], - w=image_size[:, 1], - org_h=image_size[:, 2], - org_w=image_size[:, 3], - ) - - affline_emb_B_D += additional_cond_B_D - affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() - - affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() - affline_emb_B_D = self.affline_norm(affline_emb_B_D) - - # for logging purpose - self.affline_scale_log_info = affline_scale_log_info - self.affline_emb = affline_emb_B_D - self.crossattn_emb = crossattn_emb - self.crossattn_mask = crossattn_mask - - if self.use_cross_attn_mask: - crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] - else: - crossattn_mask = None - - if self.blocks["block0"].x_format == "THWBD": - x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" - ) - crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") - - if crossattn_mask: - crossattn_mask = rearrange(crossattn_mask, "B M -> M B") - - if self.sequence_parallel: - tp_group = parallel_state.get_tensor_model_parallel_group() - # Sequence parallel requires the input tensor to be scattered along the first dimension. - assert self.block_config == "FA-CA-MLP" # Only support this block config for now - T, H, W, B, D = x.shape - # variable name x_T_H_W_B_D is no longer valid. 
x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer - x = x.view(T * H * W, 1, 1, B, D) - assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 - x = scatter_along_first_dim(x, tp_group) - - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( - T * H * W, 1, 1, B, D - ) - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group - ) - - elif self.blocks["block0"].x_format == "BTHWD": - x = x_B_T_H_W_D - else: - raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") - output = { - "x": x, - "affline_emb_B_D": affline_emb_B_D, - "crossattn_emb": crossattn_emb, - "crossattn_mask": crossattn_mask, - "rope_emb_L_1_1_D": rope_emb_L_1_1_D, - "adaln_lora_B_3D": adaln_lora_B_3D, - "original_shape": original_shape, - "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - } - return output - - def prepare_embedded_sequence( - self, - x_B_C_T_H_W: torch.Tensor, - fps: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - trajectory: Optional[torch.Tensor] = None, - frame_repeat: Optional[torch.Tensor] = None, - view_indices_B_T: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: - """ - Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. - - Args: - x_B_C_T_H_W (torch.Tensor): video - fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. - If None, a default value (`self.base_fps`) will be used. - padding_mask (Optional[torch.Tensor]): current it is not used - - Returns: - Tuple[torch.Tensor, Optional[torch.Tensor]]: - - A tensor of shape (B, T, H, W, D) with the embedded sequence. - - An optional positional embedding tensor, returned only if the positional embedding class - (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. - - Notes: - - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. - - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. - - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using - the `self.pos_embedder` with the shape [T, H, W]. - - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder` - with the fps tensor. - - Otherwise, the positional embeddings are generated without considering fps. 
- """ - if self.concat_padding_mask: - padding_mask = transforms.functional.resize( - padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST - ) - x_B_C_T_H_W = torch.cat( - [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 - ) - - if view_indices_B_T is None: - view_indices = torch.arange(self.n_views).clamp(max=self.n_views_emb - 1) # View indices [0, 1, ..., V-1] - view_indices = view_indices.to(x_B_C_T_H_W.device) - view_embedding = self.view_embeddings(view_indices) # Shape: [V, embedding_dim] - view_embedding = rearrange(view_embedding, "V D -> D V") - view_embedding = ( - view_embedding.unsqueeze(0).unsqueeze(3).unsqueeze(4).unsqueeze(5) - ) # Shape: [1, D, V, 1, 1, 1] - else: - view_indices_B_T = view_indices_B_T.clamp(max=self.n_views_emb - 1) - view_indices_B_T = view_indices_B_T.to(x_B_C_T_H_W.device).long() - view_embedding = self.view_embeddings(view_indices_B_T) # B, (V T), D - view_embedding = rearrange(view_embedding, "B (V T) D -> B D V T", V=self.n_views) - view_embedding = view_embedding.unsqueeze(-1).unsqueeze(-1) # Shape: [B, D, V, T, 1, 1] - - if self.add_repeat_frame_embedding: - if frame_repeat is None: - frame_repeat = ( - torch.zeros([x_B_C_T_H_W.shape[0], view_embedding.shape[1]]) - .to(view_embedding.device) - .to(view_embedding.dtype) - ) - frame_repeat_embedding = self.repeat_frame_embedding(frame_repeat.unsqueeze(-1)) - frame_repeat_embedding = rearrange(frame_repeat_embedding, "B V D -> B D V") - view_embedding = view_embedding + frame_repeat_embedding.unsqueeze(3).unsqueeze(4).unsqueeze(5) - - x_B_C_V_T_H_W = rearrange(x_B_C_T_H_W, "B C (V T) H W -> B C V T H W", V=self.n_views) - view_embedding = view_embedding.expand( - x_B_C_V_T_H_W.shape[0], - view_embedding.shape[1], - view_embedding.shape[2], - x_B_C_V_T_H_W.shape[3], - x_B_C_V_T_H_W.shape[4], - x_B_C_V_T_H_W.shape[5], - ) # Shape: [B, V, 3, t, H, W] - if self.concat_traj_embedding: - traj_emb = self.traj_embeddings(trajectory) - traj_emb = traj_emb.unsqueeze(2).unsqueeze(3).unsqueeze(4).unsqueeze(5) - traj_emb = traj_emb.expand( - x_B_C_V_T_H_W.shape[0], - traj_emb.shape[1], - view_embedding.shape[2], - x_B_C_V_T_H_W.shape[3], - x_B_C_V_T_H_W.shape[4], - x_B_C_V_T_H_W.shape[5], - ) # Shape: [B, V, 3, t, H, W] - - x_B_C_V_T_H_W = torch.cat([x_B_C_V_T_H_W, view_embedding, traj_emb], dim=1) - else: - x_B_C_V_T_H_W = torch.cat([x_B_C_V_T_H_W, view_embedding], dim=1) - - x_B_C_T_H_W = rearrange(x_B_C_V_T_H_W, " B C V T H W -> B C (V T) H W", V=self.n_views) - x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) - - if self.extra_per_block_abs_pos_emb: - extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) - else: - extra_pos_emb = None - - if "rope" in self.pos_emb_cls.lower(): - return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb - - if "fps_aware" in self.pos_emb_cls: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D] - else: - x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D] - return x_B_T_H_W_D, None, extra_pos_emb - - -class VideoExtendGeneralDIT(MultiCameraGeneralDIT): - def __init__(self, *args, in_channels=16 + 1, add_augment_sigma_embedding=False, **kwargs): - self.add_augment_sigma_embedding = add_augment_sigma_embedding - # extra channel for video condition mask - super().__init__(*args, in_channels=in_channels, **kwargs) - log.info(f"VideoExtendGeneralDIT in_channels: {in_channels }") - - def forward( - self, - x: torch.Tensor, - 
timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - video_cond_bool: Optional[torch.Tensor] = None, - condition_video_indicator: Optional[torch.Tensor] = None, - condition_video_input_mask: Optional[torch.Tensor] = None, - condition_video_augment_sigma: Optional[torch.Tensor] = None, - condition_video_pose: Optional[torch.Tensor] = None, - view_indices_B_T: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """Args: - condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation - condition_video_pose: (B, 1, T, H, W) tensor of pose condition - """ - B, C, T, H, W = x.shape - - if data_type == DataType.VIDEO: - assert ( - condition_video_input_mask is not None - ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" - if self.cp_group is not None: - condition_video_input_mask = rearrange( - condition_video_input_mask, "B C (V T) H W -> B C V T H W", V=self.n_views - ) - condition_video_input_mask = split_inputs_cp( - condition_video_input_mask, seq_dim=3, cp_group=self.cp_group - ) - condition_video_input_mask = rearrange( - condition_video_input_mask, "B C V T H W -> B C (V T) H W", V=self.n_views - ) - input_list = [x, condition_video_input_mask] - if condition_video_pose is not None: - if condition_video_pose.shape[2] > T: - log.warning( - f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" - ) - condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() - input_list.append(condition_video_pose) - x = torch.cat( - input_list, - dim=1, - ) - - return super().forward( - x=x, - timesteps=timesteps, - crossattn_emb=crossattn_emb, - crossattn_mask=crossattn_mask, - fps=fps, - image_size=image_size, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - condition_video_augment_sigma=condition_video_augment_sigma, - view_indices_B_T=view_indices_B_T, - **kwargs, - ) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit_video_conditioned.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit_video_conditioned.py deleted file mode 100644 index 26b979d1fc86fd7d6336205ac35c2d63cec0f35b..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/networks/general_dit_video_conditioned.py +++ /dev/null @@ -1,259 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
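# A short illustrative sketch of the "extra channel for video condition mask" pattern
# used by the VideoExtendGeneralDIT defined below (in_channels = 16 + 1): the binary
# mask marking which frames are conditioning frames is concatenated to the latent along
# the channel dimension. Shapes are example values only.
import torch

B, C, T, H, W = 2, 16, 8, 32, 32
latent = torch.randn(B, C, T, H, W)
cond_mask = torch.zeros(B, 1, T, H, W)
cond_mask[:, :, :2] = 1.0                     # e.g. the first two frames are given as condition
x_in = torch.cat([latent, cond_mask], dim=1)  # (B, 17, T, H, W), matching the 16 + 1 input channels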
- -from typing import Optional - -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import nn - -from cosmos_transfer1.diffusion.conditioner import DataType -from cosmos_transfer1.diffusion.module.parallel import split_inputs_cp -from cosmos_transfer1.diffusion.training.modules.blocks import SDXLTimestepEmbedding, SDXLTimesteps -from cosmos_transfer1.diffusion.training.networks.general_dit import GeneralDIT -from cosmos_transfer1.diffusion.training.tensor_parallel import scatter_along_first_dim -from cosmos_transfer1.utils import log - - -class VideoExtendGeneralDIT(GeneralDIT): - def __init__(self, *args, in_channels=16 + 1, add_augment_sigma_embedding=False, **kwargs): - self.add_augment_sigma_embedding = add_augment_sigma_embedding - - # extra channel for video condition mask - super().__init__(*args, in_channels=in_channels, **kwargs) - log.info(f"VideoExtendGeneralDIT in_channels: {in_channels}") - - def build_additional_timestamp_embedder(self): - super().build_additional_timestamp_embedder() - if self.add_augment_sigma_embedding: - log.info("Adding augment sigma embedding") - self.augment_sigma_embedder = nn.Sequential( - SDXLTimesteps(self.model_channels), - SDXLTimestepEmbedding(self.model_channels, self.model_channels, use_adaln_lora=self.use_adaln_lora), - ) - - def init_weights(self): - if self.add_augment_sigma_embedding: - # Initialize timestep embedding for augment sigma - nn.init.normal_(self.augment_sigma_embedder[1].linear_1.weight, std=0.02) - if self.augment_sigma_embedder[1].linear_1.bias is not None: - nn.init.constant_(self.augment_sigma_embedder[1].linear_1.bias, 0) - nn.init.normal_(self.augment_sigma_embedder[1].linear_2.weight, std=0.02) - if self.augment_sigma_embedder[1].linear_2.bias is not None: - nn.init.constant_(self.augment_sigma_embedder[1].linear_2.bias, 0) - - super().init_weights() # Call this last since it wil call TP weight init - - def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - video_cond_bool: Optional[torch.Tensor] = None, - condition_video_indicator: Optional[torch.Tensor] = None, - condition_video_input_mask: Optional[torch.Tensor] = None, - condition_video_augment_sigma: Optional[torch.Tensor] = None, - condition_video_pose: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """Args: - condition_video_augment_sigma: (B) tensor of sigma value for the conditional input augmentation - condition_video_pose: (B, 1, T, H, W) tensor of pose condition - """ - B, C, T, H, W = x.shape - - if data_type == DataType.VIDEO: - assert ( - condition_video_input_mask is not None - ), "condition_video_input_mask is required for video data type; check if your model_obj is extend_model.FSDPDiffusionModel or the base DiffusionModel" - if self.cp_group is not None: - condition_video_input_mask = split_inputs_cp( - condition_video_input_mask, seq_dim=2, cp_group=self.cp_group - ) - condition_video_indicator = split_inputs_cp( - condition_video_indicator, seq_dim=2, cp_group=self.cp_group - ) - if condition_video_pose is not None: - condition_video_pose = split_inputs_cp(condition_video_pose, seq_dim=2, cp_group=self.cp_group) - - input_list = [x, condition_video_input_mask] - if 
condition_video_pose is not None: - if condition_video_pose.shape[2] > T: - log.warning( - f"condition_video_pose has more frames than the input video: {condition_video_pose.shape} > {x.shape}" - ) - condition_video_pose = condition_video_pose[:, :, :T, :, :].contiguous() - input_list.append(condition_video_pose) - x = torch.cat( - input_list, - dim=1, - ) - - if data_type == DataType.IMAGE: - # For image, we dont have condition_video_input_mask, or condition_video_pose - # We need to add the extra channel for video condition mask - padding_channels = self.in_channels - x.shape[1] - x = torch.cat([x, torch.zeros((B, padding_channels, T, H, W), dtype=x.dtype, device=x.device)], dim=1) - else: - assert x.shape[1] == self.in_channels, f"Expected {self.in_channels} channels, got {x.shape[1]}" - return super().forward( - x=x, - timesteps=timesteps, - crossattn_emb=crossattn_emb, - crossattn_mask=crossattn_mask, - fps=fps, - image_size=image_size, - padding_mask=padding_mask, - scalar_feature=scalar_feature, - data_type=data_type, - condition_video_augment_sigma=condition_video_augment_sigma, - **kwargs, - ) - - def forward_before_blocks( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - condition_video_augment_sigma: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """ - Args: - x: (B, C, T, H, W) tensor of spatial-temp inputs - timesteps: (B, ) tensor of timesteps - crossattn_emb: (B, N, D) tensor of cross-attention embeddings - crossattn_mask: (B, N) tensor of cross-attention masks - - condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation - """ - del kwargs - assert isinstance( - data_type, DataType - ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." - original_shape = x.shape - x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( - x, - fps=fps, - padding_mask=padding_mask, - latent_condition=latent_condition, - latent_condition_sigma=latent_condition_sigma, - ) - # logging affline scale information - affline_scale_log_info = {} - - timesteps_B_D, adaln_lora_B_3D = self.t_embedder(timesteps.flatten()) - affline_emb_B_D = timesteps_B_D - affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() - - if scalar_feature is not None: - raise NotImplementedError("Scalar feature is not implemented yet.") - timesteps_B_D = timesteps_B_D + scalar_feature.mean(dim=1) - if self.additional_timestamp_channels: - additional_cond_B_D = self.prepare_additional_timestamp_embedder( - bs=x.shape[0], - fps=fps, - h=image_size[:, 0], - w=image_size[:, 1], - org_h=image_size[:, 2], - org_w=image_size[:, 3], - ) - - affline_emb_B_D += additional_cond_B_D - affline_scale_log_info["additional_cond_B_D"] = additional_cond_B_D.detach() - if self.add_augment_sigma_embedding: - if condition_video_augment_sigma is None: - # Handling image case - # Note: for video case, when there is not condition frames, we also set it as zero, see - # the augment_conditional_latent_frames function in DiffusionV2WModel and ExtendDiffusionModel. 
- assert data_type == DataType.IMAGE, "condition_video_augment_sigma is required for video data type" - condition_video_augment_sigma = torch.zeros_like(timesteps.flatten()) - - affline_augment_sigma_emb_B_D, adaln_lora_sigma_emb_B_3D = self.augment_sigma_embedder( - condition_video_augment_sigma.flatten() - ) - affline_emb_B_D = affline_emb_B_D + affline_augment_sigma_emb_B_D - affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() - affline_emb_B_D = self.affline_norm(affline_emb_B_D) - - # for logging purpose - self.affline_scale_log_info = affline_scale_log_info - self.affline_emb = affline_emb_B_D - self.crossattn_emb = crossattn_emb - self.crossattn_mask = crossattn_mask - - if self.use_cross_attn_mask: - crossattn_mask = crossattn_mask[:, None, None, :].to(dtype=torch.bool) # [B, 1, 1, length] - else: - crossattn_mask = None - - if self.blocks["block0"].x_format == "THWBD": - x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" - ) - crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") - if crossattn_mask: - crossattn_mask = rearrange(crossattn_mask, "B M -> M B") - - if self.sequence_parallel: - tp_group = parallel_state.get_tensor_model_parallel_group() - # Sequence parallel requires the input tensor to be scattered along the first dimension. - assert self.block_config == "FA-CA-MLP" # Only support this block config for now - T, H, W, B, D = x.shape - # variable name x_T_H_W_B_D is no longer valid. x is reshaped to THW*1*1*b*D and will be reshaped back in FinalLayer - x = x.view(T * H * W, 1, 1, B, D) - assert x.shape[0] % parallel_state.get_tensor_model_parallel_world_size() == 0 - x = scatter_along_first_dim(x, tp_group) - - if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.view( - T * H * W, 1, 1, B, D - ) - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = scatter_along_first_dim( - extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, tp_group - ) - - elif self.blocks["block0"].x_format == "BTHWD": - x = x_B_T_H_W_D - else: - raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") - output = { - "x": x, - "affline_emb_B_D": affline_emb_B_D, - "crossattn_emb": crossattn_emb, - "crossattn_mask": crossattn_mask, - "rope_emb_L_1_1_D": rope_emb_L_1_1_D, - "adaln_lora_B_3D": adaln_lora_B_3D, - "original_shape": original_shape, - "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, - } - return output diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/tensor_parallel.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/tensor_parallel.py deleted file mode 100644 index c756c38e53d2c71f3ab3fa2b08859fdf1bb96bc5..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/tensor_parallel.py +++ /dev/null @@ -1,102 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import torch -import torch.distributed as dist -from torch.autograd import Function - - -class AllGather(Function): - @staticmethod - def forward(ctx, tensor, process_group): - world_size = dist.get_world_size(process_group) - ctx.world_size = world_size - ctx.rank = process_group.rank() - - gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)] - dist.all_gather(gathered_tensors, tensor.contiguous(), process_group) - return torch.cat(gathered_tensors, dim=0) - - @staticmethod - def backward(ctx, grad_output): - world_size = ctx.world_size - rank = ctx.rank - - # Split the gradient tensor - grad_chunks = grad_output.chunk(world_size) - - # Select the gradient chunk for the current rank - grad_input = grad_chunks[rank] - return grad_input, None - - -def gather_along_first_dim(tensor, process_group): - return AllGather.apply(tensor, process_group) - - -class Scatter(Function): - @staticmethod - def forward(ctx, tensor, process_group): - world_size = dist.get_world_size(process_group) - ctx.world_size = world_size - ctx.process_group = process_group - rank = process_group.rank() - - # Split the tensor - tensor_chunks = tensor.chunk(world_size) - - # Select the tensor chunk for the current rank - return tensor_chunks[rank] - - @staticmethod - def backward(ctx, grad_output): - world_size = ctx.world_size - process_group = ctx.process_group - - # Gather the gradient tensor - gathered_grads = [torch.zeros_like(grad_output) for _ in range(world_size)] - dist.all_gather(gathered_grads, grad_output.contiguous(), process_group) - return torch.cat(gathered_grads, dim=0), None - - -def scatter_along_first_dim(tensor, process_group): - return Scatter.apply(tensor, process_group) - - -if __name__ == "__main__": - # Torch global setup for distributed training - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - torch.cuda.set_device(local_rank) - torch.distributed.init_process_group(world_size=world_size, rank=rank) - - # Create a tensor with gradients - x = torch.randn(10, 1, requires_grad=True, device="cuda") - - # Perform all_gather with gradient support - y = gather_along_first_dim(x, dist.group.WORLD) - print(f"{y.shape=}") - y = scatter_along_first_dim(y, dist.group.WORLD) - print(f"{y.shape=}") - - # Use the result in your computation - loss = y.sum() - loss.backward() - - # x.grad now contains the gradients - print(x.grad) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/train.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/train.py deleted file mode 100644 index 298b231ad273985d8b3a48f7df95af6fa144678b..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/train.py +++ /dev/null @@ -1,147 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import importlib -import os -import time - -import torch -import torch.distributed as dist -from loguru import logger as logging -from megatron.core import parallel_state -from omegaconf import OmegaConf - -from cosmos_transfer1.diffusion.config.config import Config -from cosmos_transfer1.utils import log, misc -from cosmos_transfer1.utils.config_helper import get_config_module, override -from cosmos_transfer1.utils.lazy_config import instantiate -from cosmos_transfer1.utils.lazy_config.lazy import LazyConfig -from cosmos_transfer1.utils.parallel_state_helper import is_tp_cp_pp_rank0 - - -@misc.timer("instantiate model") -def instantiate_model(config: Config, trainer) -> None: - misc.set_random_seed(seed=config.trainer.seed, by_rank=False) - config.model_obj.config = config.model - if getattr(config.model, "fsdp_enabled", False): - assert config.trainer.distributed_parallelism == "fsdp", "FSDP model is only supported with FSDP trainer" - log.critical("FSDP enabled") - config.model_obj.fsdp_checkpointer = trainer.checkpointer - model = instantiate(config.model_obj) - config.model_obj.fsdp_checkpointer = None - else: - model = instantiate(config.model_obj) - config.model_obj.config = None - misc.set_random_seed(seed=config.trainer.seed, by_rank=True) - return model - - -def destroy_distributed(): - log.info("Destroying distributed environment...") - if dist.is_available() and dist.is_initialized(): - try: - if parallel_state.get_tensor_model_parallel_world_size() > 1: - dist.barrier(group=parallel_state.get_tensor_model_parallel_group()) - # Global barrier across all ranks - dist.barrier() - - torch.cuda.empty_cache() - dist.destroy_process_group() - except ValueError as e: - print(f"Error destroying default process group: {e}") - - -@logging.catch(reraise=True) -def launch(config: Config, args: argparse.Namespace) -> None: - # Check that the config is valid - config.validate() - if config.trainer.timestamp_seed: - # Get the current time in microseconds - current_time = int(time.time() * 1e6) - # Combine the current time with worker_id to ensure different seeds across workers - seed = current_time % (2**32) - config.trainer.seed = seed - log.critical(f"Changed Random Seed based on timestamp. {config.trainer.seed}") - - # Freeze the config so developers don't change it during training. - config.freeze() # type: ignore - trainer = config.trainer.type(config) - # # Setup the miscellaneous stuff for reproducibility. - # log_reproducible_setup(config, args) - # Create the model - model = instantiate_model(config, trainer) - model.on_model_init_end() - # Create the dataloaders. - if args.mp0_only_dl: - log.critical( - "Using only tp_cp_pp_rank0 dataloader for faster dataloading! Make sure val dl is mock and mock data has same keys as real data." - ) - raise NotImplementedError( - "mp0_only_dl is not implemented correctly! Please revisit this code and propose a more robust impl that raise error timely! It does not do necessary check before training to confirm it can work with image / video data. Current impl is problematic for image training." 
- ) - if is_tp_cp_pp_rank0() or not args.mp0_only_dl: - dataloader_train = instantiate(config.dataloader_train) - else: - dataloader_train = instantiate(config.dataloader_val) - dataloader_val = instantiate(config.dataloader_val) - # Start training - trainer.train( - model, - dataloader_train, - dataloader_val, - ) - destroy_distributed() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Training") - parser.add_argument( - "--config", - default="cosmos_transfer1/diffusion/posttrain/config/config.py", - help="Path to the config file", - ) - parser.add_argument( - "opts", - help=""" -Modify config options at the end of the command. For Yacs configs, use -space-separated "PATH.KEY VALUE" pairs. -For python-based LazyConfig, use "path.key=value". - """.strip(), - default=None, - nargs=argparse.REMAINDER, - ) - parser.add_argument( - "--dryrun", - action="store_true", - help="Do a dry run without training. Useful for debugging the config.", - ) - parser.add_argument( - "--mp0_only_dl", - action="store_true", - help="Use only model parallel rank 0 dataloader for faster dataloading! Make sure mock data has same keys as real data.", - ) - args = parser.parse_args() - config_module = get_config_module(args.config) - config = importlib.import_module(config_module).make_config() - config = override(config, args.opts) - if args.dryrun: - os.makedirs(config.job.path_local, exist_ok=True) - LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") - print(OmegaConf.to_yaml(OmegaConf.load(f"{config.job.path_local}/config.yaml"))) - print(f"{config.job.path_local}/config.yaml") - else: - # Launch the training job. - launch(config, args) diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/utils/fsdp_helper.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/utils/fsdp_helper.py deleted file mode 100644 index 1027504a3ccde8df96c683bc38a8fdfeb0a4003e..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/utils/fsdp_helper.py +++ /dev/null @@ -1,159 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
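# A small worked example of the replica/sharding arithmetic performed by
# hsdp_device_mesh() below, assuming the 3-nodes-of-8-GPUs setup from its docstring:
world_size = 24
sharding_group_size = 4                                  # each model replica is sharded over 4 GPUs
replica_group_size = world_size // sharding_group_size   # 24 // 4 -> 6 replica groups
assert world_size % sharding_group_size == 0
# which corresponds to init_device_mesh("cuda", (6, 4), mesh_dim_names=("replicate", "shard"))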
- -from __future__ import annotations - -from contextlib import contextmanager -from functools import partial - -import torch -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - CheckpointImpl, - apply_activation_checkpointing, - checkpoint_wrapper, -) -from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp._runtime_utils import ( - _post_forward, - _post_forward_reshard, - _pre_forward, - _pre_forward_unshard, - _root_pre_forward, -) -from torch.distributed.utils import _p_assert - -from cosmos_transfer1.utils import distributed, log - - -def apply_fsdp_checkpointing(model, list_block_cls): - """apply activation checkpointing to model - returns None as model is updated directly - """ - log.critical("--> applying fdsp activation checkpointing...") - non_reentrant_wrapper = partial( - checkpoint_wrapper, - # offload_to_cpu=False, - checkpoint_impl=CheckpointImpl.NO_REENTRANT, - ) - - def check_fn(submodule): - result = False - for block_cls in list_block_cls: - if isinstance(submodule, block_cls): - result = True - break - return result - - apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn) - - -@contextmanager -def possible_fsdp_scope( - model: torch.nn.Module, -): - enabled = isinstance(model, FSDP) - if enabled: - assert not torch.is_grad_enabled(), "FSDP context should be entered with grad disabled" - handle = model._handle - args, kwargs = [0], dict(dummy=0) - with torch.autograd.profiler.record_function("FullyShardedDataParallel.possible_fsdp_scope"): - args, kwargs = _root_pre_forward(model, model, args, kwargs) - unused = None - args, kwargs = _pre_forward( - model, - handle, - _pre_forward_unshard, - model._fsdp_wrapped_module, - args, - kwargs, - ) - if handle: - _p_assert( - handle.flat_param.device == model.compute_device, - "Expected `FlatParameter` to be on the compute device " - f"{model.compute_device} but got {handle.flat_param.device}", - ) - try: - yield None - finally: - if enabled: - output = {"output": 1} - _post_forward(model, handle, _post_forward_reshard, model, unused, output) - - -def hsdp_device_mesh(replica_group_size=None, sharding_group_size=None, device=None): - """ - Initializes a device mesh for use with Hybrid Sharding strategy in FSDP (HSDP) training. - - This function requires explicit sizes for replica and sharding groups to accommodate models - whose GPU fit is unknown, providing flexibility in distributed training setups. - - Args: - replica_group_size (int): The size of each replica group. Must be provided to ensure - the model fits within the available resources. - sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to - ensure the correct distribution of model parameters. - device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda" - with the local rank as the device index. - - Returns: - A device mesh object compatible with FSDP. - - Raises: - ValueError: If replica_group_size or sharding_group_size are not provided, or if the - world size is not evenly divisible by the sharding group size. - RuntimeError: If a valid device mesh cannot be created. 
- - Usage: - If your model fits on 4 GPUS, and you have 3 nodes of 8 GPUs, then: - Sharding_Group_Size = 4 - Replica_Groups_Size = (24 total gpus, 4 per sharding group) = 6 Replica Groups - >>> device_mesh = initialize_device_mesh(replica_group_size, sharding_group_size) - >>> sharded_model = FSDP(model, device_mesh=device_mesh, ...) - """ - - # world_size = int(os.getenv("WORLD_SIZE", "1")) - world_size = distributed.get_world_size() - if sharding_group_size is None: - sharding_group_size = min(world_size, 8) - sharding_group_size = min(sharding_group_size, world_size) - if replica_group_size is None: - replica_group_size = world_size // sharding_group_size - - device = device or "cuda" - - if world_size % sharding_group_size != 0: - raise ValueError( - f"World size {world_size} is not evenly divisible by " f"sharding group size {sharding_group_size}." - ) - - if (world_size // sharding_group_size) % replica_group_size != 0: - raise ValueError( - f"The calculated number of replica groups is not evenly divisible by " - f"replica_group_size {replica_group_size}." - ) - - device_mesh = init_device_mesh( - device, (replica_group_size, sharding_group_size), mesh_dim_names=("replicate", "shard") - ) - if device_mesh is None: - raise RuntimeError("Failed to create a valid device mesh.") - - log.critical( - f"Device mesh initialized with replica group size {replica_group_size} and sharding group size {sharding_group_size}" - ) - - return device_mesh diff --git a/cosmos-transfer1/cosmos_transfer1/diffusion/training/utils/optim_instantiate.py b/cosmos-transfer1/cosmos_transfer1/diffusion/training/utils/optim_instantiate.py deleted file mode 100644 index c12dca27bcbbb622a398b9ffdde4d2ff75ddb88b..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/diffusion/training/utils/optim_instantiate.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import hydra -import torch -from torch import nn - -from cosmos_transfer1.utils import log -from cosmos_transfer1.utils.fused_adam import FusedAdam - - -def get_regular_param_group(net: nn.Module): - """ - seperate the parameters of the network into two groups: decay and no_decay. - based on nano_gpt codebase. 
- """ - param_dict = {pn: p for pn, p in net.named_parameters()} - param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} - - decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] - nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] - return decay_params, nodecay_params - - -def get_base_optimizer( - model: nn.Module, - lr: float, - weight_decay: float, - optim_type: str = "adamw", - sharding: bool = False, - **kwargs, -) -> torch.optim.Optimizer: - net_decay_param, net_nodecay_param = get_regular_param_group(model) - - num_decay_params = sum(p.numel() for p in net_decay_param) - num_nodecay_params = sum(p.numel() for p in net_nodecay_param) - net_param_total = num_decay_params + num_nodecay_params - log.critical(f"total num parameters : {net_param_total:,}") - - param_group = [ - { - "params": net_decay_param + net_nodecay_param, - "lr": lr, - "weight_decay": weight_decay, - }, - ] - - if optim_type == "adamw": - opt_cls = torch.optim.AdamW - elif optim_type == "fusedadam": - opt_cls = FusedAdam - else: - raise ValueError(f"Unknown optimizer type: {optim_type}") - - return opt_cls(param_group, **kwargs) - - -def get_base_scheduler( - optimizer: torch.optim.Optimizer, - model: nn.Module, - scheduler_config: dict, -): - net_scheduler = hydra.utils.instantiate(scheduler_config) - net_scheduler.model = model - - return torch.optim.lr_scheduler.LambdaLR( - optimizer, - lr_lambda=[ - net_scheduler.schedule, - ], - ) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/__init__.py b/cosmos-transfer1/cosmos_transfer1/utils/__init__.py deleted file mode 100644 index 3159bfe65645499015bd92609b99d476d69544e9..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/cosmos-transfer1/cosmos_transfer1/utils/base_world_generation_pipeline.py b/cosmos-transfer1/cosmos_transfer1/utils/base_world_generation_pipeline.py deleted file mode 100644 index 418eda45e53e4357fab49efcc19263f7eefc25f4..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/base_world_generation_pipeline.py +++ /dev/null @@ -1,354 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import gc -import os -from abc import ABC -from typing import Any - -import numpy as np -import torch - -from cosmos_transfer1.auxiliary.guardrail.common import presets as guardrail_presets -from cosmos_transfer1.checkpoints import T5_MODEL_CHECKPOINT -from cosmos_transfer1.utils.t5_text_encoder import CosmosT5TextEncoder - - -class BaseWorldGenerationPipeline(ABC): - def __init__( - self, - inference_type: str | None = None, - checkpoint_dir: str | None = None, - checkpoint_name: str | None = None, - has_text_input: bool = False, - offload_network: bool = False, - offload_tokenizer: bool = False, - offload_text_encoder_model: bool = False, - offload_guardrail_models: bool = False, - ): - """Initialize base world generation pipeline. - - This abstract base class provides core functionality for world generation models including: - - Model loading and initialization - - Text encoding and embedding - - Safety checks and content filtering - - Memory management through model offloading - - Args: - inference_type: The type of inference pipeline ("text2world" or "video2world") - checkpoint_dir: Root directory containing model checkpoints - checkpoint_name: Name of the specific checkpoint file to load - has_text_input: Whether the pipeline takes text input for world generation - offload_network: If True, moves main model to CPU after inference - offload_tokenizer: If True, moves tokenizer to CPU after use - offload_text_encoder_model: If True, moves T5 encoder to CPU after encoding - offload_guardrail_models: If True, moves safety models to CPU after checks - """ - self.inference_type = inference_type - self.checkpoint_dir = checkpoint_dir - self.checkpoint_name = checkpoint_name - self.has_text_input = has_text_input - - # Add offloading flags - self.offload_network = offload_network - self.offload_tokenizer = offload_tokenizer - self.offload_text_encoder_model = offload_text_encoder_model - self.offload_guardrail_models = offload_guardrail_models - - # Initialize model instances - self.text_guardrail = None - self.video_guardrail = None - self.text_encoder = None - self.model = None - - self._load_model() - - if not self.offload_text_encoder_model: - self._load_text_encoder_model() - if not self.offload_guardrail_models: - if self.has_text_input: - self._load_text_guardrail() - self._load_video_guardrail() - if not self.offload_network: - self._load_network() - if not self.offload_tokenizer: - self._load_tokenizer() - - def _load_tokenizer(self): - pass - - def _load_network(self): - pass - - def _load_model(self, checkpoint_name: str) -> Any: - """Load the world generation model from a checkpoint. - - This abstract method must be implemented by subclasses to load their specific - model architecture and weights. - - Args: - checkpoint_name: Path to the model checkpoint file - - Returns: - The loaded model instance - - Raises: - NotImplementedError: Must be implemented by subclasses - """ - pass - - def _load_text_encoder_model(self): - """Load the T5 text encoder model. - - Initializes and loads the T5 encoder model used for converting text prompts - into embeddings that condition the world generation model. - - Returns: - Loaded T5 text encoder model instance - """ - self.text_encoder = CosmosT5TextEncoder(cache_dir=os.path.join(self.checkpoint_dir, T5_MODEL_CHECKPOINT)) - - def _load_text_guardrail(self): - """Load text safety classifier models. 
- - Initializes models used for checking input prompts against safety policies. - Models are loaded from the specified guardrail directory. - """ - self.text_guardrail = guardrail_presets.create_text_guardrail_runner(checkpoint_dir=self.checkpoint_dir) - - def _load_video_guardrail(self): - """Load video safety classifier models. - - Initializes models used for validating generated video content against - safety policies. Models are loaded from the specified guardrail directory. - """ - self.video_guardrail = guardrail_presets.create_video_guardrail_runner(checkpoint_dir=self.checkpoint_dir) - - def _offload_network(self): - if self.model.model: - del self.model.model - self.model.model = None - gc.collect() - torch.cuda.empty_cache() - - def _offload_tokenizer(self): - if self.model.tokenizer: - del self.model.tokenizer - self.model.tokenizer = None - gc.collect() - torch.cuda.empty_cache() - - def _offload_guardrail_models(self): - """Offload safety classifier models to reduce memory usage. - - Moves safety models to CPU and clears GPU memory if they are no longer needed. - This helps manage memory when processing multiple inputs sequentially. - """ - if self.text_guardrail: - del self.text_guardrail - self.text_guardrail = None - if self.video_guardrail: - del self.video_guardrail - self.video_guardrail = None - gc.collect() - torch.cuda.empty_cache() - - def _offload_text_encoder_model(self): - """Offload T5 text encoder to reduce memory usage. - - Moves the T5 encoder to CPU and clears GPU memory after text encoding is complete. - This helps manage memory when processing multiple inputs sequentially. - """ - if self.text_encoder: - del self.text_encoder - self.text_encoder = None - gc.collect() - torch.cuda.empty_cache() - - def _run_model(self, *args: Any, **kwargs: Any) -> torch.Tensor: - """Generate world latents using the model. - - This abstract method must be implemented by subclasses to define their specific - generation process. - - Args: - *args: Variable positional arguments for model inference - **kwargs: Variable keyword arguments for model inference - - Returns: - torch.Tensor: Generated world representation tensor - """ - pass - - def _run_model_with_offload(self, *args: Any, **kwargs: Any) -> torch.Tensor: - """Generate world representation with memory management. - - Handles loading the model before inference and offloading afterward if enabled. - This helps minimize GPU memory usage during inference. - - Args: - *args: Arguments passed to _run_model - **kwargs: Keyword arguments passed to _run_model - - Returns: - np.ndarray: Generated world representation as numpy array - """ - pass - - def _run_guardrail_on_prompt(self, prompt: str) -> bool: - """Check if prompt meets safety requirements. - - Validates the input prompt against safety policies using loaded guardrail models. - - Args: - prompt: Raw text prompt to validate - - Returns: - bool: True if prompt passes all safety checks, False otherwise - """ - return guardrail_presets.run_text_guardrail(prompt, self.text_guardrail) - - def _run_guardrail_on_prompt_with_offload(self, prompt: str) -> bool: - """Check prompt safety with memory management. - - Validates prompt safety while handling model loading/offloading to manage memory. 
- - Args: - prompt: Raw text prompt to validate - - Returns: - bool: True if prompt passes all safety checks, False otherwise - """ - if self.offload_guardrail_models: - self._load_text_guardrail() - - is_safe = self._run_guardrail_on_prompt(prompt) - - if self.offload_guardrail_models: - self._offload_guardrail_models() - - return is_safe - - def _run_guardrail_on_video(self, video: np.ndarray) -> np.ndarray | None: - """Check if video meets safety requirements. - - Validates generated video content against safety policies using guardrail models. - - Args: - video: Video frames to validate - - Returns: - np.ndarray: Processed video if safe, None if unsafe - """ - return guardrail_presets.run_video_guardrail(video, self.video_guardrail) - - def _run_guardrail_on_video_with_offload(self, video: np.ndarray) -> np.ndarray | None: - """Check if generated video meets safety requirements. - - Args: - video: Video frames to validate - - Returns: - np.ndarray: Processed video frames if safe, None otherwise - - Note: - Guardrail models are offloaded after checks if enabled. - """ - if self.offload_guardrail_models: - self._load_video_guardrail() - - video = self._run_guardrail_on_video(video) - - if self.offload_guardrail_models: - self._offload_guardrail_models() - return video - - def _run_text_embedding_on_prompt( - self, prompts: list[str], **kwargs: Any - ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - """Convert text prompts to embeddings. - - Processes text prompts into embedding tensors that condition the generation model. - - Args: - prompts: List of text prompts to encode - **kwargs: Additional arguments for text encoding - - Returns: - tuple containing: - - List of text embedding tensors for each prompt - - List of attention masks for each embedding - """ - - embeddings = [] - masks = [] - for prompt in prompts: - embedding, mask = self.text_encoder.encode_prompts( - [prompt], - **kwargs, - ) - embeddings.append(embedding) - masks.append(mask) - - return embeddings, masks - - def _run_text_embedding_on_prompt_with_offload( - self, prompts: list[str], **kwargs: Any - ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - """Convert text prompt into embeddings using T5 encoder. - - Args: - prompt: Processed and validated text prompt - - Returns: - Text embedding tensor to condition diffusion model - - Note: - T5 model is offloaded after encoding if enabled. - """ - if self.offload_text_encoder_model: - self._load_text_encoder_model() - - embeddings, masks = self._run_text_embedding_on_prompt(prompts, **kwargs) - - if self.offload_text_encoder_model: - self._offload_text_encoder_model() - return embeddings, masks - - def _run_tokenizer_decoding(self, samples: torch.Tensor) -> np.ndarray: - """Decode model outputs into final world representation. - - This abstract method must be implemented by subclasses to convert raw model - outputs into their specific world representation format. - - Args: - samples: Raw output tensor from the generation model - - Returns: - np.ndarray: Decoded world representation - """ - pass - - def generate(self, *args: Any, **kwargs: Any): - """Generate world representation. - - This abstract method must be implemented by subclasses to convert raw model - outputs into their specific world representation format. 
-
-        Args:
-            *args: Variable positional arguments for model inference
-            **kwargs: Variable keyword arguments for model inference
-        """
-        pass
diff --git a/cosmos-transfer1/cosmos_transfer1/utils/callback.py b/cosmos-transfer1/cosmos_transfer1/utils/callback.py
deleted file mode 100644
index 0b669a922787d54e882768529c7d8ba35b1cd282..0000000000000000000000000000000000000000
--- a/cosmos-transfer1/cosmos_transfer1/utils/callback.py
+++ /dev/null
@@ -1,456 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import annotations
-
-import time
-import warnings
-from typing import TYPE_CHECKING, Any, Callable, List, Optional
-
-import omegaconf
-import torch
-import torch.utils.data
-import tqdm
-from megatron.core import parallel_state
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
-from cosmos_transfer1.utils import distributed, log
-from cosmos_transfer1.utils.lazy_config import instantiate
-from cosmos_transfer1.utils.misc import get_local_tensor_if_DTensor
-
-if TYPE_CHECKING:
-    from cosmos_transfer1.utils.config import Config
-    from cosmos_transfer1.utils.model import Model
-    from cosmos_transfer1.utils.trainer import Trainer
-
-
-class CallBackGroup:
-    """A class for hosting a collection of callback objects.
-
-    It is used to execute callback functions of multiple callback objects with the same method name.
-    When callbackgroup.func(args) is executed, internally it loops through the objects in self._callbacks and runs
-    self._callbacks[0].func(args), self._callbacks[1].func(args), etc. The method name and arguments should match.
-
-    Attributes:
-        _callbacks (list[Callback]): List of callback objects.
-    """
-
-    def __init__(self, config: Config, trainer: Trainer) -> None:
-        """Initializes the list of callback objects.
-
-        Args:
-            config (Config): The config object for the codebase.
-            trainer (Trainer): The main trainer.
-        """
-        self._callbacks = []
-        callback_configs = config.trainer.callbacks
-        if callback_configs:
-            if isinstance(callback_configs, list) or isinstance(callback_configs, omegaconf.listconfig.ListConfig):
-                warnings.warn(
-                    "The 'config.trainer.callbacks' parameter should be a dict instead of a list. "
-                    "Please update your code",
-                    DeprecationWarning,
-                    stacklevel=2,
-                )
-                callback_configs = {f"callback_{i}": v for i, v in enumerate(callback_configs)}
-            for callback_name, current_callback_cfg in callback_configs.items():
-                if "_target_" not in current_callback_cfg:
-                    log.critical(
-                        f"Callback {callback_name} is missing the '_target_' field. \n Skip {current_callback_cfg}"
-                    )
-                    continue
-                log.critical(f"Instantiating callback {callback_name}: {current_callback_cfg}")
-                _callback = instantiate(current_callback_cfg)
-                assert isinstance(_callback, Callback), f"{current_callback_cfg} is not a valid callback."
- _callback.config = config - _callback.trainer = trainer - self._callbacks.append(_callback) - - def __getattr__(self, method_name: str) -> Callable: - """Loops through the callback objects to call the corresponding callback function. - - Args: - method_name (str): Callback method name. - """ - - def multi_callback_wrapper(*args, **kwargs) -> None: - for callback in self._callbacks: - assert hasattr(callback, method_name) - method = getattr(callback, method_name) - assert callable(method) - _ = method(*args, **kwargs) - - return multi_callback_wrapper - - -class Callback: - """The base class for all callbacks. - - All callbacks should inherit from this class and adhere to the established method names and signatures. - """ - - def __init__(self, config: Optional["Config"] = None, trainer: Optional["Trainer"] = None): - """Initializes a Callback object. - - Args: - config (Optional[Config]): The configuration object for the codebase, if available. - trainer (Optional[Trainer]): The main trainer handling the training loop, if available. - - Notes: - The config and trainer parameters are optional to maintain backward compatibility. - In future releases, these parameters will be removed. Upon using these parameters, a deprecation - warning will be issued. - - """ - if config is not None or trainer is not None: - warnings.warn( - "The 'config' and 'trainer' parameters are deprecated and will be removed in a future release. " - "Please update your code to create Callback instances without these parameters.", - DeprecationWarning, - stacklevel=2, - ) - del config, trainer - - def on_train_start(self, model: Model, iteration: int = 0) -> None: - pass - - def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: - pass - - def on_before_forward(self, iteration: int = 0) -> None: - pass - - def on_after_forward(self, iteration: int = 0) -> None: - pass - - def on_before_backward( - self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0 - ) -> None: - pass - - def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None: - pass - - def on_before_dataloading(self, iteration: int = 0) -> None: - pass - - def on_after_dataloading(self, iteration: int = 0) -> None: - pass - - def on_optimizer_init_start(self) -> None: - pass - - def on_optimizer_init_end(self) -> None: - pass - - def on_before_optimizer_step( - self, - model_ddp: distributed.DistributedDataParallel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int = 0, - ) -> None: - pass - - def on_before_zero_grad( - self, - model_ddp: distributed.DistributedDataParallel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - iteration: int = 0, - ) -> None: - pass - - def on_training_step_end( - self, - model: Model, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - pass - - def on_validation_start( - self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0 - ) -> None: - pass - - def on_validation_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: - pass - - def on_validation_step_end( - self, - model: Model, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - 
pass - - def on_validation_end(self, model: Model, iteration: int = 0) -> None: - pass - - def on_load_checkpoint_start(self, model: Model) -> None: - pass - - def on_load_checkpoint_end(self, model: Model) -> None: - pass - - def on_load_checkpoint(self, model: Model, state_dict: dict[Any]) -> None: - pass - - def on_save_checkpoint_start(self, model: Model, iteration: int = 0) -> None: - pass - - def on_save_checkpoint_end(self, model: Model, iteration: int = 0) -> None: - pass - - def on_save_checkpoint_success(self, iteration: int = 0) -> None: - pass - - def on_save_checkpoint(self, model: Model, state_dict: dict[Any]) -> None: - pass - - def on_train_end(self, model: Model, iteration: int = 0) -> None: - pass - - def on_app_end(self) -> None: - pass - - -class EMAModelCallback(Callback): - """The callback class for tracking EMA model weights.""" - - def on_train_start(self, model: Model, iteration: int = 0) -> None: - # Set up the EMA model weight tracker. - if model.config.ema.enabled: - assert hasattr(model, "ema"), "EMA should be initialized from Model" - # EMA model must be kept in FP32 precision. - model.ema = model.ema.to(dtype=torch.float32) - else: - assert not hasattr(model, "ema"), "There should be no EMA initialized." - - def on_training_step_end( - self, - model: Model, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - # Update the EMA model with the new regular weights. - if model.config.ema.enabled: - model.ema.update_average(model, iteration) - - -class ProgressBarCallback(Callback): - """The callback class for visualizing the training/validation progress bar in the console.""" - - @distributed.rank0_only - def on_train_start(self, model: Model, iteration: int = 0) -> None: - self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training") - - @distributed.rank0_only - def on_training_step_end( - self, - model: Model, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - self.train_pbar.update() - - @distributed.rank0_only - def on_validation_start( - self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0 - ) -> None: - if self.config.trainer.max_val_iter is not None: - num_iter = self.config.trainer.max_val_iter - else: - num_iter = len(dataloader_val) - assert num_iter is not None and num_iter > 0, f"Invalid number of validation iterations: {num_iter}" - self.val_pbar = tqdm.trange(num_iter, desc="Validating", position=1, leave=False) - - @distributed.rank0_only - def on_validation_step_end( - self, - model: Model, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - self.val_pbar.update() - - @distributed.rank0_only - def on_validation_end(self, model: Model, iteration: int = 0) -> None: - self.val_pbar.close() - - @distributed.rank0_only - def on_train_end(self, model: Model, iteration: int = 0) -> None: - self.trainer.checkpointer.finalize() - self.train_pbar.close() - - -class IterationLoggerCallback(Callback): - """The callback class for visualizing the training/validation progress bar in the console.""" - - @distributed.rank0_only - def on_train_start(self, model: Model, iteration: int = 0) -> None: - # self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training") - self.start_iteration_time = time.time() - 
self.elapsed_iteration_time = 0 - - @distributed.rank0_only - def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: - self.start_iteration_time = time.time() - - @distributed.rank0_only - def on_training_step_end( - self, - model: Model, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - self.elapsed_iteration_time += time.time() - self.start_iteration_time - - if iteration % self.config.trainer.logging_iter == 0: - avg_time = self.elapsed_iteration_time / self.config.trainer.logging_iter - log.info(f"Iteration: {iteration}, average iter time: {avg_time:2f}, total loss {loss.item():4f}") - - self.elapsed_iteration_time = 0 - - -@torch.jit.script -def _fused_nan_to_num(params: List[torch.Tensor]): - for param in params: - torch.nan_to_num(param, nan=0.0, posinf=0.0, neginf=0.0, out=param) - - -class GradClip(Callback): - def __init__( - self, clip_norm=1.0, force_finite: bool = True, model_key: Optional[str] = None, fsdp_enabled: bool = False - ): - self.clip_norm = clip_norm - self.force_finite = force_finite - self.model_key = model_key - self.fsdp_enabled = fsdp_enabled - - def on_before_optimizer_step( - self, - model_ddp: distributed.DistributedDataParallel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int = 0, - ) -> None: - del optimizer, scheduler - if isinstance(model_ddp, distributed.DistributedDataParallel): - model = model_ddp.module - else: - model = model_ddp - - # select sub-network if specified - if self.model_key is not None: - items = self.model_key.split(".") - for item in items: - model = getattr(model, item) - - if self.force_finite: - params = [] - for param in model.parameters(): - if param.grad is not None: - params.append(param.grad) - # torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad) - _fused_nan_to_num(params) - - # check if FSDP is used - # total_norm - if isinstance(model, FSDP) and self.fsdp_enabled: - model.clip_grad_norm_(self.clip_norm) - else: - torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True) - - -class GradClipCallback(Callback): - """The callback class for gradient clipping.""" - - def __init__( - self, - config: Optional["Config"] = None, - trainer: Optional["Trainer"] = None, - grad_clip_norm: float = 1.0, - ): - super().__init__(config, trainer) - self.grad_clip_norm = grad_clip_norm - - def on_before_optimizer_step( - self, - model_ddp: distributed.DistributedDataParallel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int = 0, - ) -> None: - grad_scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model_ddp.module.parameters(), max_norm=self.grad_clip_norm) - - -class LowPrecisionCallback(Callback): - """The callback class handling low precision training""" - - def __init__(self, update_iter: int, config: Optional["Config"] = None, trainer: Optional["Trainer"] = None): - super().__init__(config, trainer) - self.update_iter = update_iter - - def on_train_start(self, model: Model, iteration: int = 0) -> None: - assert model.precision in [ - torch.bfloat16, - torch.float16, - torch.half, - ], "LowPrecisionCallback must use a low precision dtype." 
- self.precision_type = model.precision - - def on_training_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: - for k, v in data.items(): - if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]): - data[k] = v.to(dtype=self.precision_type) - - def on_validation_step_start(self, model: Model, data: dict[str, torch.Tensor], iteration: int = 0) -> None: - for k, v in data.items(): - if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]): - data[k] = v.to(dtype=self.precision_type) - - def on_before_zero_grad( - self, - model_ddp: distributed.DistributedDataParallel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - iteration: int = 0, - ) -> None: - if iteration % self.update_iter == 0: - if getattr(optimizer, "master_weights", False): - params, master_params = [], [] - for group, group_master in zip(optimizer.param_groups, optimizer.param_groups_master): - for p, p_master in zip(group["params"], group_master["params"]): - params.append(get_local_tensor_if_DTensor(p.data)) - master_params.append(p_master.data) - torch._foreach_copy_(params, master_params) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/checkpointer.py b/cosmos-transfer1/cosmos_transfer1/utils/checkpointer.py deleted file mode 100644 index 2c8617ff11e5ddfa69a0391986e79d0d20b7000c..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/checkpointer.py +++ /dev/null @@ -1,237 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import os -import threading -from typing import TYPE_CHECKING - -import torch - -from cosmos_transfer1.utils import callback, distributed, log, misc -from cosmos_transfer1.utils.model import Model - -if TYPE_CHECKING: - from cosmos_transfer1.utils.config import CheckpointConfig, JobConfig - - -class Checkpointer: - """The checkpointer class. Supports checkpoint saving/loading to local disk.""" - - def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): - """Constructor of the checkpointer. - - Args: - config_checkpoint (CheckpointConfig): The config object for the checkpointer. - """ - # Set the callback functions. 
- self.callbacks = callbacks - self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints" - self.strict_resume = config_checkpoint.strict_resume - self.load_path = config_checkpoint.load_path or None - self.load_training_state = config_checkpoint.load_training_state - self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state - self.save_thread = None - - def save( - self, - model: Model, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int, - ) -> None: - """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. - - Args: - model (Model): The PyTorch model. - optimizer (torch.optim.Optimizer): The model optimizer. - scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. - grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). - iteration (int): Current iteration number. - """ - self.callbacks.on_save_checkpoint_start(model, iteration) - - checkpoint_file = f"iter_{iteration:09}.pt" - - if distributed.get_rank() == 0: - state_dict = dict( - model=model.state_dict(), - optimizer=optimizer.state_dict(), - scheduler=scheduler.state_dict(), - grad_scaler=grad_scaler.state_dict(), - iteration=iteration, - ) - state_dict = misc.to(state_dict, device="cpu") - self.callbacks.on_save_checkpoint(model, state_dict=state_dict) - # Wait for previous saver thread to end. - if self.save_thread: - self.save_thread.join() - # Run the checkpoint saver in a separate thread. - self.save_thread = threading.Thread( - target=self._save_worker_local, - daemon=False, - args=(state_dict, checkpoint_file, distributed.get_rank()), - ) - self.save_thread.start() - - # Note: Checkpoints are saved on a separate thread and this callback is not accurate. - # Please check logs from on_save_checkpoint_success() for better accuracy - self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) - - @misc.timer("checkpoint saving (local)") - def _save_worker_local(self, state_dict: dict[str, torch.Tensor], checkpoint_file: str, rank: int = 0) -> None: - """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training). - - Args: - state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler. - checkpoint_file (str): The file name of the model checkpoint. - rank (int): GPU device (default: 0). - """ - checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file) - os.makedirs(self.checkpoint_dir_local, exist_ok=True) - try: - torch.save(state_dict, checkpoint_path) - if rank == 0: - self._write_latest_checkpoint_file(checkpoint_file) - log.success(f"Saved checkpoint (local): {checkpoint_path}") - iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) - self.callbacks.on_save_checkpoint_success(iteration=iteration) - except Exception as e: # noqa: BLE001 - log.exception(f"Checkpoint failed to save (local): {e}") - - @misc.timer("checkpoint loading") - def load( - self, - model: Model, - optimizer: torch.optim.Optimizer | None = None, - scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, - grad_scaler: torch.amp.GradScaler | None = None, - ) -> int: - """Load network weights and optimizer states from a checkpoint in a single process. - - The priority of the checkpoint loading logic is: - 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. - 2. 
If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. - - This is typically used for inference mode. - - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. - 3. If none of the above, randomly initialize the model parameters and train from scratch. - - Args: - model (Model): The PyTorch model. - optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). - scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). - grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). - - Returns: - iteration (int): the iteration number to start/resume from. - """ - self.callbacks.on_load_checkpoint_start(model) - - latest_checkpoint_file = self._read_latest_checkpoint_file() - if latest_checkpoint_file is not None: - # 1. Resume training from latest_checkpoint.txt under the same name. - checkpoint_dir = self.checkpoint_dir_local - checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file) - resume = True - only_resume_scheduler = True - else: - if self.load_path: - # 2. Load the module weights specified by config_checkpoint.path. - checkpoint_path = self.load_path - resume = self.load_training_state - only_resume_scheduler = self.only_load_scheduler_state - else: - # 3. Randomly initialize the model parameters and train from scratch. - checkpoint_path = None - resume = False - only_resume_scheduler = False - # Load checkpoint. - if checkpoint_path is not None: - self._check_checkpoint_exists(checkpoint_path) - log.info(f"Loading checkpoint (local): {checkpoint_path}") - state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage, weights_only=False) - log.success(f"Complete loading checkpoint (local): {checkpoint_path}") - self.callbacks.on_load_checkpoint(model, state_dict=state_dict) - # Load the state dicts. - log.info("- Loading the model...") - if "model" in state_dict: - model.load_state_dict(state_dict["model"], strict=self.strict_resume) - else: - model.load_state_dict(state_dict, strict=self.strict_resume) - if resume or only_resume_scheduler: - iteration = state_dict["iteration"] - assert scheduler - log.info("- Loading the scheduler...") - scheduler.load_state_dict(state_dict["scheduler"]) - scheduler.last_epoch = iteration - else: - iteration = 0 - if resume: - assert optimizer - log.info("- Loading the optimizer...") - optimizer.load_state_dict(state_dict["optimizer"]) - log.info("- Loading the gradient scaler...") - grad_scaler.load_state_dict(state_dict["grad_scaler"]) - log.success(f"Done with loading the checkpoint (iteration {iteration}).") - else: - log.success("Done with loading the checkpoint.") - else: - # Checkpoint not found and not specified. We will train everything from scratch. - iteration = 0 - log.info("Training from scratch.") - torch.cuda.empty_cache() - - self.callbacks.on_load_checkpoint_end(model) - - return iteration - - def _read_latest_checkpoint_file(self) -> str | None: - """Get the file name of the latest saved checkpoint. If it doesn't exist, return None. - - Returns: - checkpoint_file (str | None): file name of the latest saved checkpoint. 
- """ - checkpoint_file = None - latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") - if os.path.isfile(latest_path): - checkpoint_file = open(latest_path).read().strip() - return checkpoint_file - - def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: - """Track the file name of the latest saved checkpoint. - - Args: - checkpoint_file (str): file name of the latest saved checkpoint. - """ - content = f"{checkpoint_file}\n" - latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") - with open(latest_path, "w") as file: - file.write(content) - - def _check_checkpoint_exists(self, checkpoint_path: str) -> None: - """If the file checkpoint_path does not exist, raise an error. - - Args: - checkpoint_path (str): full path to the checkpoint. - """ - if not os.path.exists(checkpoint_path): - raise FileNotFoundError(f"File not found (local): {checkpoint_path}") - - def finalize(self) -> None: - """Finalize the checkpointer.""" - if self.save_thread: - self.save_thread.join() diff --git a/cosmos-transfer1/cosmos_transfer1/utils/config.py b/cosmos-transfer1/cosmos_transfer1/utils/config.py deleted file mode 100644 index 85253c9443e0b5239cbebabf47b5c958060a2561..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/config.py +++ /dev/null @@ -1,239 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import os -from typing import Any, Dict, Optional, Type, TypeVar, Union - -import attrs -import torch - -try: - from megatron.core import ModelParallelConfig - - USE_MEGATRON = True -except ImportError: - USE_MEGATRON = False - print("Megatron-core is not installed.") - -from cosmos_transfer1.utils.callback import EMAModelCallback, ProgressBarCallback -from cosmos_transfer1.utils.ddp_config import DDPConfig, make_freezable -from cosmos_transfer1.utils.lazy_config import LazyCall as L -from cosmos_transfer1.utils.lazy_config import LazyDict -from cosmos_transfer1.utils.misc import Color - - -def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str: - """ - Recursively pretty prints attrs objects with color. 
- """ - - assert attrs.has(obj.__class__) - - lines: list[str] = [] - for attribute in attrs.fields(obj.__class__): - value = getattr(obj, attribute.name) - if attrs.has(value.__class__): - if use_color: - lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":") - else: - lines.append(" " * indent + "* " + attribute.name + ":") - lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color)) - else: - if use_color: - lines.append( - " " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value) - ) - else: - lines.append(" " * indent + "* " + attribute.name + ": " + str(value)) - return "\n".join(lines) - - -@make_freezable -@attrs.define(slots=False) -class JobConfig: - # Project name. - project: str = "" - # Experiment name. - group: str = "" - # Run/job name. - name: str = "" - - @property - def path(self) -> str: - return f"{self.project}/{self.group}/{self.name}" - - @property - def path_local(self) -> str: - local_root = os.environ.get("OUTPUT_ROOT", "checkpoints") - return f"{local_root}/{self.path}" - - -@make_freezable -@attrs.define(slots=False) -class EMAConfig: - # Enable tracking a set of exponential moving average (EMA) weights. - enabled: bool = False - # EMA decay rate. - beta: float = 0.9999 - # Enable removing "_orig_mod-" from buffer names that is added by torch.compile - torch_compile_buffer_renaming: bool = False - - -@make_freezable -@attrs.define(slots=False) -class CuDNNConfig: - # Set to True for better reproducibility of the results (only using deterministic cudnn functions). - deterministic: bool = False - # If set to True, cudnn will benchmark several algorithms and pick the fastest one. - benchmark: bool = True - - -@make_freezable -@attrs.define(slots=False) -class JITConfig: - # Enable exporting a JIT compiled model. - enabled: bool = False - # Input tensor shape, for example input. - input_shape: Union[list[int], None] = None - # Device to compile onto. - device: str = "cuda" - # # Data type to compile onto. - dtype: str = "bfloat16" - # Strict mode for PyTorch JIT. - strict: bool = True - - -@make_freezable -@attrs.define(slots=False) -class CheckpointConfig: - # possible checkpoint class - type: Optional[Dict] = None - # for dcp, whether to use async mode - dcp_async_mode_enabled: bool = False - # Save the checkpoint every N iterations. - save_iter: int = 999999999 - # Path of model weights to resume the checkpoint from. - load_path: str = "" - # Whether to load the training states (optimizer/scheduler/grad-scaler) from the checkpoint path. - load_training_state: bool = False - # Whether to load the scheduler state only from the checkpoint path. If load_training_state is True, this will be ignored. - only_load_scheduler_state: bool = False - # Load state_dict to the models in strict mode. - strict_resume: bool = True - # Print detailed information during checkpoint saving/loading. - verbose: bool = True - # Configs for JIT compiling EMA model. - jit: JITConfig = attrs.field(factory=JITConfig) - # keys not to resume from the checkpoint, choices: ["model", "optim", "scheduler", "trainer"] - keys_not_to_resume: list[str] = [] - # Whether to use the local filesystem for broadcasting checkpoint data (used for Tensor Parallel Checkpointer). - broadcast_via_filesystem: bool = False - load_ema_to_reg: bool = False - - -@make_freezable -@attrs.define(slots=False) -class TrainerConfig: - from cosmos_transfer1.utils.trainer import Trainer - - type: Type[Trainer] = Trainer - # Set the callback class. 
- # Defaults to the callbacks below. - callbacks: LazyDict = LazyDict( - dict( - ema=L(EMAModelCallback)(), - progress_bar=L(ProgressBarCallback)(), - ) - ) - # distributed parallelism strategy - distributed_parallelism: str = "ddp" - # Distributed data parallel configs. - ddp: DDPConfig = attrs.field(factory=DDPConfig) - # cuDNN configs. - cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig) - # Set the random seed. - seed: int = 0 - # Gradient scaler arguments (for torch.amp.GradScaler). - grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False)) - # Maximum number of iterations to train the model. - max_iter: int = 999999999 - # Maximum number of iterations to validate the model. If None, validate on the entire dataset. - max_val_iter: int | None = None - # How often we log the training stats. - logging_iter: int = 100 - # Whether we want to run the validation routines. - run_validation: bool = True - # How often we evaluate on the validation set. - validation_iter: int = 999999999 - # Kill the process after N seconds since the last iteration (usually means dead job). - timeout_period: int = 999999999 - # Tensor memory organization format. - memory_format: torch.memory_format = torch.preserve_format - # Gradient accumulation (update step every N iteration). - grad_accum_iter: int = 1 - # Whether to use the timestamp as the seed. Needed to ensure real randomness in loading data. - timestamp_seed: bool = True - # # Profiling config - # profiling: Profiling = attrs.field(factory=Profiling) - - -@make_freezable -@attrs.define(slots=False) -class Config: - """Config for a job. - - See /README.md/Configuration System for more info. - """ - - # Model configs. - model: LazyDict - # Optimizer configs. - optimizer: LazyDict = LazyDict(dict(dummy=None)) - # Scheduler configs. - scheduler: LazyDict = LazyDict(dict(dummy=None)) - # Training data configs. - dataloader_train: LazyDict = LazyDict(dict(dummy=None)) - # Validation data configs. - dataloader_val: LazyDict = LazyDict(dict(dummy=None)) - - # Training job configs. - job: JobConfig = attrs.field(factory=JobConfig) - - # Trainer configs. - trainer: TrainerConfig = attrs.field(factory=TrainerConfig) - - # Megatron-Core configs - if USE_MEGATRON: - # Megatron-Core configs - model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig) - else: - model_parallel: None = None - - # Checkpointer configs. - checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig) - - def pretty_print(self, use_color: bool = False) -> str: - return _pretty_print_attrs_instance(self, 0, use_color) - - def to_dict(self) -> dict[str, Any]: - return attrs.asdict(self) - - def validate(self) -> None: - """Validate that the config has all required fields.""" - assert self.job.project != "", "Project name is required." - assert self.job.group != "", "Group name is required." - assert self.job.name != "", "Job name is required." diff --git a/cosmos-transfer1/cosmos_transfer1/utils/config_helper.py b/cosmos-transfer1/cosmos_transfer1/utils/config_helper.py deleted file mode 100644 index 4a056195ab148cc4034e24655dd252af4d80d5ff..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/config_helper.py +++ /dev/null @@ -1,201 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import importlib -import os -import pkgutil -import sys -from dataclasses import fields as dataclass_fields -from dataclasses import is_dataclass -from typing import Any, Dict, Optional - -import attr -import attrs -from hydra import compose, initialize -from hydra.core.config_store import ConfigStore -from hydra.core.global_hydra import GlobalHydra -from omegaconf import DictConfig, OmegaConf - -from cosmos_transfer1.utils import log -from cosmos_transfer1.utils.config import Config - - -def is_attrs_or_dataclass(obj) -> bool: - """ - Check if the object is an instance of an attrs class or a dataclass. - - Args: - obj: The object to check. - - Returns: - bool: True if the object is an instance of an attrs class or a dataclass, False otherwise. - """ - return is_dataclass(obj) or attr.has(type(obj)) - - -def get_fields(obj): - """ - Get the fields of an attrs class or a dataclass. - - Args: - obj: The object to get fields from. Must be an instance of an attrs class or a dataclass. - - Returns: - list: A list of field names. - - Raises: - ValueError: If the object is neither an attrs class nor a dataclass. - """ - if is_dataclass(obj): - return [field.name for field in dataclass_fields(obj)] - elif attr.has(type(obj)): - return [field.name for field in attr.fields(type(obj))] - else: - raise ValueError("The object is neither an attrs class nor a dataclass.") - - -def override(config: Config, overrides: Optional[list[str]] = None) -> Config: - """ - :param config: the instance of class `Config` (usually from `make_config`) - :param overrides: list of overrides for config - :return: the composed instance of class `Config` - """ - # Store the class of the config for reconstruction after overriding. - # config_class = type(config) - - # Convert Config object to a DictConfig object - config_dict = attrs.asdict(config) - config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True}) - # Enforce "--" separator between the script arguments and overriding configs. 
- if overrides: - if overrides[0] != "--": - raise ValueError('Hydra config overrides must be separated with a "--" token.') - overrides = overrides[1:] - # Use Hydra to handle overrides - cs = ConfigStore.instance() - cs.store(name="config", node=config_omegaconf) - if not GlobalHydra().is_initialized(): - with initialize(version_base=None): - config_omegaconf = compose(config_name="config", overrides=overrides) - OmegaConf.resolve(config_omegaconf) - else: - config_omegaconf = compose(config_name="config", overrides=overrides) - OmegaConf.resolve(config_omegaconf) - - def config_from_dict(ref_instance: Any, kwargs: Any) -> Any: - """ - Construct an instance of the same type as ref_instance using the provided dictionary or data or unstructured data - - Args: - ref_instance: The reference instance to determine the type and fields when needed - kwargs: A dictionary of keyword arguments to use for constructing the new instance or primitive data or unstructured data - - Returns: - Any: A new instance of the same type as ref_instance constructed using the provided kwargs or the primitive data or unstructured data - - Raises: - AssertionError: If the fields do not match or if extra keys are found. - Exception: If there is an error constructing the new instance. - """ - is_type = is_attrs_or_dataclass(ref_instance) - if not is_type: - return kwargs - else: - ref_fields = set(get_fields(ref_instance)) - assert isinstance(kwargs, dict) or isinstance( - kwargs, DictConfig - ), "kwargs must be a dictionary or a DictConfig" - keys = set(kwargs.keys()) - - # ref_fields must equal to or include all keys - extra_keys = keys - ref_fields - assert ref_fields == keys or keys.issubset( - ref_fields - ), f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}" - - resolved_kwargs: Dict[str, Any] = {} - for f in keys: - resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f]) - try: - new_instance = type(ref_instance)(**resolved_kwargs) - except Exception as e: - log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}") - log.error(e) - raise e - return new_instance - - config = config_from_dict(config, config_omegaconf) - - return config - - -def get_config_module(config_file: str) -> str: - if not config_file.endswith(".py"): - log.error("Config file cannot be specified as module.") - log.error("Please provide the path to the Python config file (relative to the Cosmos root).") - assert os.path.isfile(config_file), f"Cosmos config file ({config_file}) not found." - # Convert to importable module format. - config_module = config_file.replace("/", ".").replace(".py", "") - return config_module - - -def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None: - """ - Import all modules from the specified package path recursively. - - This function is typically used in conjunction with Hydra to ensure that all modules - within a specified package are imported, which is necessary for registering configurations. - - Example usage: - ```python - import_all_modules_from_package( "cosmos_transfer1.diffusion.config.inference", reload=True, skip_underscore=False) - ``` - - Args: - package_path (str): The dotted path to the package from which to import all modules. - reload (bool): Flag to determine whether to reload modules if they're already imported. - skip_underscore (bool): If True, skips importing modules that start with an underscore. 
- """ - log.debug(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}") - package = importlib.import_module(package_path) - package_directory = package.__path__ - - def import_modules_recursively(directory: str, prefix: str) -> None: - """ - Recursively imports or reloads all modules in the given directory. - - Args: - directory (str): The file system path to the current package directory. - prefix (str): The module prefix (e.g., 'transfer1.diffusion.config'). - """ - for _, module_name, is_pkg in pkgutil.iter_modules([directory]): - if skip_underscore and module_name.startswith("_"): - log.debug(f"Skipping module {module_name} as it starts with an underscore") - continue - - full_module_name = f"{prefix}.{module_name}" - log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}") - - if full_module_name in sys.modules and reload: - importlib.reload(sys.modules[full_module_name]) - else: - importlib.import_module(full_module_name) - - if is_pkg: - sub_package_directory = os.path.join(directory, module_name) - import_modules_recursively(sub_package_directory, full_module_name) - - for directory in package_directory: - import_modules_recursively(directory, package_path) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/ddp_config.py b/cosmos-transfer1/cosmos_transfer1/utils/ddp_config.py deleted file mode 100644 index 00e4045d066d948b6fed782fc3182abaa14390e9..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/ddp_config.py +++ /dev/null @@ -1,108 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import TypeVar - -import attrs - -T = TypeVar("T") - - -def _is_attrs_instance(obj: object) -> bool: - """ - Helper function to check if an object is an instance of an attrs-defined class. - - Args: - obj: The object to check. - - Returns: - bool: True if the object is an instance of an attrs-defined class, False otherwise. - """ - return hasattr(obj, "__attrs_attrs__") - - -def make_freezable(cls: T) -> T: - """ - A decorator that adds the capability to freeze instances of an attrs-defined class. - - NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need - to hack on a "_is_frozen" attribute. - - This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime. - Once an instance is frozen, its attributes cannot be changed. It also recursively freezes - any attrs-defined objects that are attributes of the class. - - Usage: - @make_freezable - @attrs.define(slots=False) - class MyClass: - attribute1: int - attribute2: str - - obj = MyClass(1, 'a') - obj.freeze() # Freeze the instance - obj.attribute1 = 2 # Raises AttributeError - - Args: - cls: The class to be decorated. - - Returns: - The decorated class with added freezing capability. 
- """ - - if not hasattr(cls, "__dict__"): - raise TypeError( - "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped " - "class was defined with `@attrs.define(slots=False)`" - ) - - original_setattr = cls.__setattr__ - - def setattr_override(self, key, value) -> None: # noqa: ANN001 - """ - Override __setattr__ to allow modifications during initialization - and prevent modifications once the instance is frozen. - """ - if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen": - raise AttributeError("Cannot modify frozen instance") - original_setattr(self, key, value) # type: ignore - - cls.__setattr__ = setattr_override # type: ignore - - def freeze(self: object) -> None: - """ - Freeze the instance and all its attrs-defined attributes. - """ - for _, value in attrs.asdict(self, recurse=False).items(): - if _is_attrs_instance(value) and hasattr(value, "freeze"): - value.freeze() - self._is_frozen = True # type: ignore - - cls.freeze = freeze # type: ignore - - return cls - - -@make_freezable -@attrs.define(slots=False) -class DDPConfig: - # Traverse the computation graph to find parameters that don't receive gradients. - find_unused_parameters: bool = False - # Set to True if the computation graph does not change during the whole training loop. - static_graph: bool = True - # Set to True if we want to synchronize buffers. Set to False if the sync is going to be handled elsewhere. - broadcast_buffers: bool = True diff --git a/cosmos-transfer1/cosmos_transfer1/utils/device.py b/cosmos-transfer1/cosmos_transfer1/utils/device.py deleted file mode 100644 index db486afabd4ae0bf11feb05d8a4efd96690ce64b..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/device.py +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import os - -import pynvml - - -class Device: - """A class to handle NVIDIA GPU device operations using NVML. - - This class provides an interface to access and manage NVIDIA GPU devices, - including retrieving device information and CPU affinity settings. - - Attributes: - _nvml_affinity_elements (int): Number of 64-bit elements needed to represent CPU affinity - """ - - _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore - - def __init__(self, device_idx: int): - """Initialize a Device instance for a specific GPU. - - Args: - device_idx (int): Index of the GPU device to manage - - Raises: - NVMLError: If the device cannot be found or initialized - """ - super().__init__() - self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) - - def get_cpu_affinity(self) -> list[int]: - """Get the CPU affinity mask for this GPU device. - - Retrieves the CPU affinity mask indicating which CPU cores are assigned - to this GPU device. The affinity is returned as a list of CPU core indices. 
- - Returns: - list[int]: List of CPU core indices that have affinity with this GPU - - Raises: - NVMLError: If the CPU affinity information cannot be retrieved - - Example: - >>> device = Device(0) - >>> device.get_cpu_affinity() - [0, 1, 2, 3] # Shows this GPU has affinity with CPU cores 0-3 - """ - affinity_string = "" - for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements): - # assume nvml returns list of 64 bit ints - affinity_string = "{:064b}".format(j) + affinity_string - affinity_list = [int(x) for x in affinity_string] - affinity_list.reverse() # so core 0 is in 0th element of list - return [i for i, e in enumerate(affinity_list) if e != 0] diff --git a/cosmos-transfer1/cosmos_transfer1/utils/distributed.py b/cosmos-transfer1/cosmos_transfer1/utils/distributed.py deleted file mode 100644 index 522265a5890bf5b0ec49551e04661d6ad2d65492..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/distributed.py +++ /dev/null @@ -1,312 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import collections -import collections.abc -import ctypes -import functools -import os -from contextlib import contextmanager -from datetime import timedelta -from typing import Any, Callable, Optional, TypeVar - -import pynvml -import torch -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel as DDP - -from cosmos_transfer1.utils import log -from cosmos_transfer1.utils.ddp_config import DDPConfig -from cosmos_transfer1.utils.device import Device - -try: - from megatron.core import parallel_state -except ImportError: - print("Megatron-core is not installed.") - -T = TypeVar("T") - - -def init() -> int | None: - """Initialize distributed training.""" - # Set GPU affinity. - pynvml.nvmlInit() - local_rank = int(os.getenv("LOCAL_RANK", 0)) - device = Device(local_rank) - os.sched_setaffinity(0, device.get_cpu_affinity()) - # Set up NCCL communication. - os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0" - os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" - if dist.is_available(): - if dist.is_initialized(): - return torch.cuda.current_device() - torch.cuda.set_device(local_rank) - # Get the timeout value from environment variable - timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800) - # Convert the timeout to an integer (if it isn't already) and then to a timedelta - timeout_timedelta = timedelta(seconds=int(timeout_seconds)) - dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta) - log.critical( - f"Initialized distributed training with local rank {local_rank} with timeout {timeout_seconds}", - rank0_only=False, - ) - # Increase the L2 fetch granularity for faster speed. - _libcudart = ctypes.CDLL("libcudart.so") - # Set device limit on the current device. 
- p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int)) - _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) - _libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05)) - log.info(f"Training with {get_world_size()} GPUs.") - - -def get_rank(group: Optional[dist.ProcessGroup] = None) -> int: - """Get the rank (GPU device) of the worker. - - Returns: - rank (int): The rank of the worker. - """ - rank = 0 - if dist.is_available() and dist.is_initialized(): - rank = dist.get_rank(group) - return rank - - -def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int: - """Get world size. How many GPUs are available in this job. - - Returns: - world_size (int): The total number of GPUs available in this job. - """ - world_size = 1 - if dist.is_available() and dist.is_initialized(): - world_size = dist.get_world_size(group) - return world_size - - -def is_rank0() -> bool: - """Check if current process is the master GPU. - - Returns: - (bool): True if this function is called from the master GPU, else False. - """ - return get_rank() == 0 - - -def rank0_only(func: Callable) -> Callable: - """Apply this function only to the master GPU. - - Example usage: - @rank0_only - def func(x): - return x + 3 - - Args: - func (Callable): a function. - - Returns: - (Callable): A function wrapper executing the function only on the master GPU. - """ - - @functools.wraps(func) - def wrapper(*args, **kwargs): # noqa: ANN202 - if is_rank0(): - return func(*args, **kwargs) - else: - return None - - return wrapper - - -def barrier() -> None: - """Barrier for all GPUs.""" - if dist.is_available() and dist.is_initialized(): - dist.barrier() - - -def rank0_first(func: Callable) -> Callable: - """run the function on rank 0 first, then on other ranks.""" - - @functools.wraps(func) - def wrapper(*args, **kwargs): # noqa: ANN202 - if is_rank0(): - result = func(*args, **kwargs) - barrier() - if not is_rank0(): - result = func(*args, **kwargs) - return result - - return wrapper - - -class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel): - """This extends torch.nn.parallel.DistributedDataParallel with .training_step(). - - This borrows the concept of `forward-redirection` from Pytorch lightning. It wraps an coreModel such that - model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling - model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward -> - training_step), allowing us to preserve the function names and signatures. - """ - - def __init__(self, model: torch.nn.Module, *args, **kwargs): - super().__init__(model, *args, **kwargs) - - def training_step(self, *args, **kwargs) -> Any: - # Cache the original model.forward() method. - original_forward = self.module.forward - - def wrapped_training_step(*_args, **_kwargs): # noqa: ANN202 - # Unpatch immediately before calling training_step() because itself may want to call the real forward. - self.module.forward = original_forward - # The actual .training_step(). - return self.module.training_step(*_args, **_kwargs) - - # Patch the original_module's forward so we can redirect the arguments back to the real method. - self.module.forward = wrapped_training_step - # Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step(). - # Without calling self.forward() or model.forward() explciitly, implicit hooks are also executed. 
- return self(*args, **kwargs) - - -def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DistributedDataParallel: - """Wraps the model to enable data parallalism for training across multiple GPU devices. - - Args: - config_ddp (DDPConfig): The data parallel config. - model (torch.nn.Module): The PyTorch module. - - Returns: - model (torch.nn.Module | DistributedDataParallel): The data parallel model wrapper - if distributed environment is available, otherwise return the original model. - """ - if dist.is_available() and dist.is_initialized(): - local_rank = int(os.getenv("LOCAL_RANK", 0)) - try: - ddp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) - except Exception as e: - log.info(e) - log.info("parallel_state not initialized, treating all GPUs equally for DDP") - ddp_group = None - - model = DistributedDataParallel( - model, - device_ids=[local_rank], - output_device=local_rank, - find_unused_parameters=config_ddp.find_unused_parameters, - static_graph=config_ddp.static_graph, - broadcast_buffers=config_ddp.broadcast_buffers, - process_group=ddp_group, - ) - return model - - -@contextmanager -def ddp_sync_grad(model, enabled): - r""" - Context manager to enable/disable gradient synchronizations across DDP processes for DDP model. - Modified from: - https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync - Note that this is incompatible with static_graph=True and will be an no-op if static_graph=True. - - Within this context, gradients will be accumulated on module - variables, which will later be synchronized in the first - forward-backward pass exiting the context. - - .. warning:: - The forward pass should be included inside the context manager, or - else gradients will still be synchronized. - """ - assert isinstance(model, torch.nn.Module) - if isinstance(model, DistributedDataParallel): - old_require_backward_grad_sync = model.require_backward_grad_sync - if model.static_graph and model.require_backward_grad_sync != enabled: - if model.show_sync_grad_static_graph_warning: - log.warning("DDP static_graph=True is incompatible with sync_grad(). Performance will be reduced.") - model.show_sync_grad_static_graph_warning = False - else: - model.require_backward_grad_sync = enabled - try: - yield - finally: - if isinstance(model, DistributedDataParallel): - model.require_backward_grad_sync = old_require_backward_grad_sync - - -def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]: - """Aggregate the list of data batches from all devices and process the results. - - This is used for gathering validation data batches with utils.dataloader.DistributedEvalSampler. - It will return the data/output of the entire validation set in its original index order. The sizes of data_batches - in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be - created before calling dis.all_gather(). - - Args: - data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where - leaf entries are tensors. - - Returns: - data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where - leaf entries are concatenated tensors. - """ - if isinstance(data_batches[0], torch.Tensor): - # Concatenate the local data batches. 
- data_concat = torch.cat(data_batches, dim=0) # type: ignore - # Get the largest number of local samples from all ranks to determine whether to dummy-pad on this rank. - max_num_local_samples = torch.tensor(len(data_concat), device="cuda") - dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX) - if len(data_concat) < max_num_local_samples: - assert len(data_concat) + 1 == max_num_local_samples - dummy = torch.empty_like(data_concat[:1]) - data_concat = torch.cat([data_concat, dummy], dim=0) - dummy_count = torch.tensor(1, device="cuda") - else: - dummy_count = torch.tensor(0, device="cuda") - # Get all concatenated batches from all ranks and concatenate again. - dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM) - data_concat = all_gather_tensor(data_concat.contiguous()) - data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1) - # Remove the dummy samples. - if dummy_count > 0: - data_collate = data_collate[:-dummy_count] - elif isinstance(data_batches[0], collections.abc.Mapping): - data_collate = dict() - for key in data_batches[0].keys(): - data_collate[key] = collate_batches([data[key] for data in data_batches]) # type: ignore - else: - raise TypeError - return data_collate - - -@torch.no_grad() -def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]: - """Gather the corresponding tensor from all GPU devices to a list. - - Args: - tensor (torch.Tensor): Pytorch tensor. - - Returns: - tensor_list (list[torch.Tensor]): A list of Pytorch tensors gathered from all GPU devices. - """ - tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())] - dist.all_gather(tensor_list, tensor) - return tensor_list - - -def broadcast(tensor, src, group=None, async_op=False): - world_size = get_world_size() - if world_size < 2: - return tensor - dist.broadcast(tensor, src=src, group=group, async_op=async_op) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/__init__.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/__init__.py deleted file mode 100644 index 3159bfe65645499015bd92609b99d476d69544e9..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
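Editor's note (illustrative, not part of the patch): the cosmos_transfer1/utils/distributed.py module deleted above is organised around a handful of rank helpers (get_rank, is_rank0, rank0_only, barrier) plus the DDP wrapper and gather utilities. The sketch below condenses the rank0_only pattern from the deleted file, assuming only torch.distributed; log_once is a hypothetical example function, and rank 0 is assumed whenever no process group has been initialised.

import functools

import torch.distributed as dist


def get_rank() -> int:
    # Fall back to rank 0 when the process group is not initialised (e.g. single-GPU runs).
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


def rank0_only(func):
    # Execute the wrapped function on the master rank only; other ranks receive None.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if get_rank() == 0:
            return func(*args, **kwargs)
        return None

    return wrapper


@rank0_only
def log_once(msg: str) -> None:
    print(msg)

Calling log_once("hello") then prints once per job rather than once per process, which is the behaviour the deleted rank0_only helper provided.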
diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/__init__.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/__init__.py deleted file mode 100644 index 86481ffa7a207636fabe4c37715644a55012130e..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from cosmos_transfer1.utils.easy_io.backends.base_backend import BaseStorageBackend -from cosmos_transfer1.utils.easy_io.backends.http_backend import HTTPBackend -from cosmos_transfer1.utils.easy_io.backends.local_backend import LocalBackend -from cosmos_transfer1.utils.easy_io.backends.registry_utils import backends, prefix_to_backends, register_backend - -__all__ = [ - "BaseStorageBackend", - "LocalBackend", - "HTTPBackend", - "register_backend", - "backends", - "prefix_to_backends", -] diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/base_backend.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/base_backend.py deleted file mode 100644 index 2db3b921f0b6fdb3aaea867c0bb3cafdb5e59888..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/base_backend.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import os.path as osp -from abc import ABCMeta, abstractmethod - - -def mkdir_or_exist(dir_name, mode=0o777): - if dir_name == "": - return - dir_name = osp.expanduser(dir_name) - os.makedirs(dir_name, mode=mode, exist_ok=True) - - -def has_method(obj, method): - return hasattr(obj, method) and callable(getattr(obj, method)) - - -class BaseStorageBackend(metaclass=ABCMeta): - """Abstract class of storage backends. - - All backends need to implement two apis: :meth:`get()` and - :meth:`get_text()`. - - - :meth:`get()` reads the file as a byte stream. - - :meth:`get_text()` reads the file as texts. - """ - - # a flag to indicate whether the backend can create a symlink for a file - # This attribute will be deprecated in future. - _allow_symlink = False - - @property - def allow_symlink(self): - return self._allow_symlink - - @property - def name(self): - return self.__class__.__name__ - - @abstractmethod - def get(self, filepath): - pass - - @abstractmethod - def get_text(self, filepath): - pass diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/http_backend.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/http_backend.py deleted file mode 100644 index 8ed642517717cc4531a92a44c1a3b84a18924ba0..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/http_backend.py +++ /dev/null @@ -1,91 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -from contextlib import contextmanager -from pathlib import Path -from typing import Generator, Union -from urllib.request import urlopen - -from cosmos_transfer1.utils.easy_io.backends.base_backend import BaseStorageBackend - - -class HTTPBackend(BaseStorageBackend): - """HTTP and HTTPS storage bachend.""" - - def get(self, filepath: str) -> bytes: - """Read bytes from a given ``filepath``. - - Args: - filepath (str): Path to read data. - - Returns: - bytes: Expected bytes object. - - Examples: - >>> backend = HTTPBackend() - >>> backend.get('http://path/of/file') - b'hello world' - """ - return urlopen(filepath).read() - - def get_text(self, filepath, encoding="utf-8") -> str: - """Read text from a given ``filepath``. - - Args: - filepath (str): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - - Examples: - >>> backend = HTTPBackend() - >>> backend.get_text('http://path/of/file') - 'hello world' - """ - return urlopen(filepath).read().decode(encoding) - - @contextmanager - def get_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]: - """Download a file from ``filepath`` to a local temporary directory, - and return the temporary path. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Args: - filepath (str): Download a file from ``filepath``. - - Yields: - Iterable[str]: Only yield one temporary path. - - Examples: - >>> backend = HTTPBackend() - >>> # After existing from the ``with`` clause, - >>> # the path will be removed - >>> with backend.get_local_path('http://path/of/file') as path: - ... # do something here - """ - try: - f = tempfile.NamedTemporaryFile(delete=False) - f.write(self.get(filepath)) - f.close() - yield f.name - finally: - os.remove(f.name) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/local_backend.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/local_backend.py deleted file mode 100644 index a99247f973e9ca67dbb6508cfc815691356c2569..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/local_backend.py +++ /dev/null @@ -1,550 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import io -import os -import os.path as osp -import shutil -from contextlib import contextmanager -from pathlib import Path -from typing import Generator, Iterator, Optional, Tuple, Union - -from cosmos_transfer1.utils.easy_io.backends.base_backend import BaseStorageBackend, mkdir_or_exist - - -class LocalBackend(BaseStorageBackend): - """Raw local storage backend.""" - - _allow_symlink = True - - def get(self, filepath: Union[str, Path]) -> bytes: - """Read bytes from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - - Returns: - bytes: Expected bytes object. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.get(filepath) - b'hello world' - """ - with open(filepath, "rb") as f: - value = f.read() - return value - - def get_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str: - """Read text from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.get_text(filepath) - 'hello world' - """ - with open(filepath, encoding=encoding) as f: - text = f.read() - return text - - def put(self, obj: Union[bytes, io.BytesIO], filepath: Union[str, Path]) -> None: - """Write bytes to a given ``filepath`` with 'wb' mode. - - Note: - ``put`` will create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (bytes): Data to be written. - filepath (str or Path): Path to write data. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.put(b'hello world', filepath) - """ - mkdir_or_exist(osp.dirname(filepath)) - if isinstance(obj, io.BytesIO): - obj.seek(0) - obj = obj.getvalue() - with open(filepath, "wb") as f: - f.write(obj) - - def put_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None: - """Write text to a given ``filepath`` with 'w' mode. - - Note: - ``put_text`` will create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.put_text('hello world', filepath) - """ - mkdir_or_exist(osp.dirname(filepath)) - with open(filepath, "w", encoding=encoding) as f: - f.write(obj) - - def exists(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. 
- - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.exists(filepath) - True - """ - return osp.exists(filepath) - - def isdir(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/dir' - >>> backend.isdir(filepath) - True - """ - return osp.isdir(filepath) - - def isfile(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.isfile(filepath) - True - """ - return osp.isfile(filepath) - - def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: - r"""Concatenate all file paths. - - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of \*filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - - Returns: - str: The result of concatenation. - - Examples: - >>> backend = LocalBackend() - >>> filepath1 = '/path/of/dir1' - >>> filepath2 = 'dir2' - >>> filepath3 = 'path/of/file' - >>> backend.join_path(filepath1, filepath2, filepath3) - '/path/of/dir/dir2/path/of/file' - """ - return osp.join(filepath, *filepaths) - - @contextmanager - def get_local_path( - self, - filepath: Union[str, Path], - ) -> Generator[Union[str, Path], None, None]: - """Only for unified API and does nothing. - - Args: - filepath (str or Path): Path to be read data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Examples: - >>> backend = LocalBackend() - >>> with backend.get_local_path('abc/def.jpg') as path: - ... # do something here - """ - yield filepath - - def copyfile( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Copy a file src to dst and return the destination file. - - src and dst should have the same prefix. If dst specifies a directory, - the file will be copied into dst using the base filename from src. If - dst specifies a file that already exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to dst. - - Returns: - str: The destination file. - - Raises: - SameFileError: If src and dst are the same file, a SameFileError - will be raised. - - Examples: - >>> backend = LocalBackend() - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> backend.copyfile(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to '/path1/of/dir/file' - >>> backend.copyfile(src, dst) - '/path1/of/dir/file' - """ - return shutil.copy(src, dst) - - def copytree( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Recursively copy an entire directory tree rooted at src to a - directory named dst and return the destination directory. - - src and dst should have the same prefix and dst must not already exist. - - - Args: - src (str or Path): A directory to be copied. 
- dst (str or Path): Copy directory to dst. - - Returns: - str: The destination directory. - - Raises: - FileExistsError: If dst had already existed, a FileExistsError will - be raised. - - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> backend.copytree(src, dst) - '/path/of/dir2' - """ - return shutil.copytree(src, dst) - - def copyfile_from_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Copy a local file src to dst and return the destination file. Same - as :meth:`copyfile`. - - Args: - src (str or Path): A local file to be copied. - dst (str or Path): Copy file to dst. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Raises: - SameFileError: If src and dst are the same file, a SameFileError - will be raised. - - Examples: - >>> backend = LocalBackend() - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> backend.copyfile_from_local(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to - >>> backend.copyfile_from_local(src, dst) - '/path1/of/dir/file' - """ - return self.copyfile(src, dst) - - def copytree_from_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Recursively copy an entire directory tree rooted at src to a - directory named dst and return the destination directory. Same as - :meth:`copytree`. - - Args: - src (str or Path): A local directory to be copied. - dst (str or Path): Copy directory to dst. - - Returns: - str: The destination directory. - - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> backend.copytree_from_local(src, dst) - '/path/of/dir2' - """ - return self.copytree(src, dst) - - def copyfile_to_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - dst_type: Optional[str] = None, - ) -> str: - """Copy the file src to local dst and return the destination file. Same - as :meth:`copyfile`. - - If dst specifies a directory, the file will be copied into dst using - the base filename from src. If dst specifies a file that already - exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to to local dst. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> backend = LocalBackend() - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> backend.copyfile_to_local(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to - >>> backend.copyfile_to_local(src, dst) - '/path1/of/dir/file' - """ - return self.copyfile(src, dst) - - def copytree_to_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Recursively copy an entire directory tree rooted at src to a local - directory named dst and return the destination directory. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to local dst. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - - Returns: - str: The destination directory. 
- - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> backend.copytree_from_local(src, dst) - '/path/of/dir2' - """ - return self.copytree(src, dst) - - def remove(self, filepath: Union[str, Path]) -> None: - """Remove a file. - - Args: - filepath (str or Path): Path to be removed. - - Raises: - IsADirectoryError: If filepath is a directory, an IsADirectoryError - will be raised. - FileNotFoundError: If filepath does not exist, an FileNotFoundError - will be raised. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.remove(filepath) - """ - if not self.exists(filepath): - raise FileNotFoundError(f"filepath {filepath} does not exist") - - if self.isdir(filepath): - raise IsADirectoryError("filepath should be a file") - - os.remove(filepath) - - def rmtree(self, dir_path: Union[str, Path]) -> None: - """Recursively delete a directory tree. - - Args: - dir_path (str or Path): A directory to be removed. - - Examples: - >>> dir_path = '/path/of/dir' - >>> backend.rmtree(dir_path) - """ - shutil.rmtree(dir_path) - - def copy_if_symlink_fails( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> bool: - """Create a symbolic link pointing to src named dst. - - If failed to create a symbolic link pointing to src, directly copy src - to dst instead. - - Args: - src (str or Path): Create a symbolic link pointing to src. - dst (str or Path): Create a symbolic link named dst. - - Returns: - bool: Return True if successfully create a symbolic link pointing - to src. Otherwise, return False. - - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> backend.copy_if_symlink_fails(src, dst) - True - >>> src = '/path/of/dir' - >>> dst = '/path1/of/dir1' - >>> backend.copy_if_symlink_fails(src, dst) - True - """ - try: - os.symlink(src, dst) - return True - except Exception: - if self.isfile(src): - self.copyfile(src, dst) - else: - self.copytree(src, dst) - return False - - def list_dir_or_file( - self, - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False, - ) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - - Args: - dir_path (str or Path): Path of the directory. - list_dir (bool): List the directories. Defaults to True. - list_file (bool): List the path of files. Defaults to True. - suffix (str or tuple[str], optional): File suffix that we are - interested in. Defaults to None. - recursive (bool): If set to True, recursively scan the directory. - Defaults to False. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - - Examples: - >>> backend = LocalBackend() - >>> dir_path = '/path/of/dir' - >>> # list those files and directories in current directory - >>> for file_path in backend.list_dir_or_file(dir_path): - ... print(file_path) - >>> # only list files - >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False): - ... print(file_path) - >>> # only list directories - >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False): - ... print(file_path) - >>> # only list files ending with specified suffixes - >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): - ... 
print(file_path) - >>> # list all files and directory recursively - >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True): - ... print(file_path) - """ # noqa: E501 - if list_dir and suffix is not None: - raise TypeError("`suffix` should be None when `list_dir` is True") - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError("`suffix` must be a string or tuple of strings") - - root = dir_path - - def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive): - for entry in os.scandir(dir_path): - if not entry.name.startswith(".") and entry.is_file(): - rel_path = osp.relpath(entry.path, root) - if (suffix is None or rel_path.endswith(suffix)) and list_file: - yield rel_path - elif osp.isdir(entry.path): - if list_dir: - rel_dir = osp.relpath(entry.path, root) - yield rel_dir - if recursive: - yield from _list_dir_or_file(entry.path, list_dir, list_file, suffix, recursive) - - return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/registry_utils.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/registry_utils.py deleted file mode 100644 index acd77b13978a6fc73ba6ec8820b7cd0d532d329e..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/backends/registry_utils.py +++ /dev/null @@ -1,127 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Optional, Type, Union - -from cosmos_transfer1.utils.easy_io.backends.base_backend import BaseStorageBackend -from cosmos_transfer1.utils.easy_io.backends.http_backend import HTTPBackend -from cosmos_transfer1.utils.easy_io.backends.local_backend import LocalBackend - -backends: dict = {} -prefix_to_backends: dict = {} - - -def _register_backend( - name: str, - backend: Type[BaseStorageBackend], - force: bool = False, - prefixes: Union[str, list, tuple, None] = None, -): - """Register a backend. - - Args: - name (str): The name of the registered backend. - backend (BaseStorageBackend): The backend class to be registered, - which must be a subclass of :class:`BaseStorageBackend`. - force (bool): Whether to override the backend if the name has already - been registered. Defaults to False. - prefixes (str or list[str] or tuple[str], optional): The prefix - of the registered storage backend. Defaults to None. 
- """ - global backends, prefix_to_backends - - if not isinstance(name, str): - raise TypeError("the backend name should be a string, " f"but got {type(name)}") - - if not inspect.isclass(backend): - raise TypeError(f"backend should be a class, but got {type(backend)}") - if not issubclass(backend, BaseStorageBackend): - raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend") - - if name in backends and not force: - raise ValueError( - f"{name} is already registered as a storage backend, " 'add "force=True" if you want to override it' - ) - backends[name] = backend - - if prefixes is not None: - if isinstance(prefixes, str): - prefixes = [prefixes] - else: - assert isinstance(prefixes, (list, tuple)) - - for prefix in prefixes: - if prefix in prefix_to_backends and not force: - raise ValueError( - f"{prefix} is already registered as a storage backend," - ' add "force=True" if you want to override it' - ) - - prefix_to_backends[prefix] = backend - - -def register_backend( - name: str, - backend: Optional[Type[BaseStorageBackend]] = None, - force: bool = False, - prefixes: Union[str, list, tuple, None] = None, -): - """Register a backend. - - Args: - name (str): The name of the registered backend. - backend (class, optional): The backend class to be registered, - which must be a subclass of :class:`BaseStorageBackend`. - When this method is used as a decorator, backend is None. - Defaults to None. - force (bool): Whether to override the backend if the name has already - been registered. Defaults to False. - prefixes (str or list[str] or tuple[str], optional): The prefix - of the registered storage backend. Defaults to None. - - This method can be used as a normal method or a decorator. - - Examples: - - >>> class NewBackend(BaseStorageBackend): - ... def get(self, filepath): - ... return filepath - ... - ... def get_text(self, filepath): - ... return filepath - >>> register_backend('new', NewBackend) - - >>> @register_backend('new') - ... class NewBackend(BaseStorageBackend): - ... def get(self, filepath): - ... return filepath - ... - ... def get_text(self, filepath): - ... return filepath - """ - if backend is not None: - _register_backend(name, backend, force=force, prefixes=prefixes) - return - - def _register(backend_cls): - _register_backend(name, backend_cls, force=force, prefixes=prefixes) - return backend_cls - - return _register - - -register_backend("local", LocalBackend, prefixes="") -register_backend("http", HTTPBackend, prefixes=["http", "https"]) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/easy_io.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/easy_io.py deleted file mode 100644 index cb2959aa0d9c8fda8b6a54ed150afd3f1b7f1a58..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/easy_io.py +++ /dev/null @@ -1,1066 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import warnings -from contextlib import contextmanager -from io import BytesIO, StringIO -from pathlib import Path -from typing import IO, Any, Generator, Iterator, Optional, Tuple, Union - -from cosmos_transfer1.utils.easy_io.backends import backends, prefix_to_backends -from cosmos_transfer1.utils.easy_io.file_client import FileClient -from cosmos_transfer1.utils.easy_io.handlers import file_handlers - -backend_instances: dict = {} - - -def is_filepath(filepath): - return isinstance(filepath, (str, Path)) - - -def _parse_uri_prefix(uri: Union[str, Path]) -> str: - """Parse the prefix of uri. - - Args: - uri (str or Path): Uri to be parsed that contains the file prefix. - - Examples: - >>> _parse_uri_prefix('/home/path/of/your/file') - '' - >>> _parse_uri_prefix('s3://path/of/your/file') - 's3' - >>> _parse_uri_prefix('clusterName:s3://path/of/your/file') - 's3' - - Returns: - str: Return the prefix of uri if the uri contains '://'. Otherwise, - return ''. - """ - assert is_filepath(uri) - uri = str(uri) - # if uri does not contains '://', the uri will be handled by - # LocalBackend by default - if "://" not in uri: - return "" - else: - prefix, _ = uri.split("://") - if ":" in prefix: - _, prefix = prefix.split(":") - return prefix - - -def _get_file_backend(prefix: str, backend_args: dict): - """Return a file backend based on the prefix or backend_args. - - Args: - prefix (str): Prefix of uri. - backend_args (dict): Arguments to instantiate the corresponding - backend. - """ - # backend name has a higher priority - if "backend" in backend_args: - # backend_args should not be modified - backend_args_bak = backend_args.copy() - backend_name = backend_args_bak.pop("backend") - backend = backends[backend_name](**backend_args_bak) - else: - backend = prefix_to_backends[prefix](**backend_args) - return backend - - -def get_file_backend( - uri: Union[str, Path, None] = None, - *, - backend_args: Optional[dict] = None, - enable_singleton: bool = False, - backend_key: Optional[str] = None, -): - """Return a file backend based on the prefix of uri or backend_args. - - Args: - uri (str or Path): Uri to be parsed that contains the file prefix. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - enable_singleton (bool): Whether to enable the singleton pattern. - If it is True, the backend created will be reused if the - signature is same with the previous one. Defaults to False. - backend_key: str: The key to register the backend. Defaults to None. - - Returns: - BaseStorageBackend: Instantiated Backend object. 
- - Examples: - >>> # get file backend based on the prefix of uri - >>> uri = 's3://path/of/your/file' - >>> backend = get_file_backend(uri) - >>> # get file backend based on the backend_args - >>> backend = get_file_backend(backend_args={'backend': 's3'}) - >>> # backend name has a higher priority if 'backend' in backend_args - >>> backend = get_file_backend(uri, backend_args={'backend': 's3'}) - """ - global backend_instances - if backend_key is not None: - if backend_key in backend_instances: - return backend_instances[backend_key] - - if backend_args is None: - backend_args = {} - - if uri is None and "backend" not in backend_args and backend_key is None: - raise ValueError( - 'uri should not be None when "backend" does not exist in ' "backend_args and backend_key is None" - ) - - if uri is not None: - prefix = _parse_uri_prefix(uri) - else: - prefix = "" - - if enable_singleton: - unique_key = f"{prefix}:{json.dumps(backend_args)}" - if unique_key in backend_instances: - return backend_instances[unique_key] - - backend = _get_file_backend(prefix, backend_args) - backend_instances[unique_key] = backend - if backend_key is not None: - backend_instances[backend_key] = backend - return backend - else: - backend = _get_file_backend(prefix, backend_args) - return backend - - -def get( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> bytes: - """Read bytes from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Returns: - bytes: Expected bytes object. - - Examples: - >>> filepath = '/path/of/file' - >>> get(filepath) - b'hello world' - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - return backend.get(filepath) - - -def get_text( - filepath: Union[str, Path], - encoding="utf-8", - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> str: - """Read text from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Returns: - str: Expected text reading from ``filepath``. - - Examples: - >>> filepath = '/path/of/file' - >>> get_text(filepath) - 'hello world' - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - return backend.get_text(filepath, encoding) - - -def put( - obj: bytes, - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> None: - """Write bytes to a given ``filepath`` with 'wb' mode. - - Note: - ``put`` should create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (bytes): Data to be written. - filepath (str or Path): Path to write data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. 
- - Examples: - >>> filepath = '/path/of/file' - >>> put(b'hello world', filepath) - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - backend.put(obj, filepath) - - -def put_text( - obj: str, - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> None: - """Write text to a given ``filepath`` with 'w' mode. - - Note: - ``put_text`` should create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str, optional): The encoding format used to open the - ``filepath``. Defaults to 'utf-8'. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Examples: - >>> filepath = '/path/of/file' - >>> put_text('hello world', filepath) - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - backend.put_text(obj, filepath) - - -def exists( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - - Examples: - >>> filepath = '/path/of/file' - >>> exists(filepath) - True - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - return backend.exists(filepath) - - -def isdir( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - - Examples: - >>> filepath = '/path/of/dir' - >>> isdir(filepath) - True - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - return backend.isdir(filepath) - - -def isfile( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. 
- - Examples: - >>> filepath = '/path/of/file' - >>> isfile(filepath) - True - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - return backend.isfile(filepath) - - -def join_path( - filepath: Union[str, Path], - *filepaths: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Union[str, Path]: - r"""Concatenate all file paths. - - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of \*filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - *filepaths (str or Path): Other paths to be concatenated. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Returns: - str: The result of concatenation. - - Examples: - >>> filepath1 = '/path/of/dir1' - >>> filepath2 = 'dir2' - >>> filepath3 = 'path/of/file' - >>> join_path(filepath1, filepath2, filepath3) - '/path/of/dir/dir2/path/of/file' - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - return backend.join_path(filepath, *filepaths) - - -@contextmanager -def get_local_path( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Generator[Union[str, Path], None, None]: - """Download data from ``filepath`` and write the data to local path. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Note: - If the ``filepath`` is a local path, just return itself and it will - not be released (removed). - - Args: - filepath (str or Path): Path to be read data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Yields: - Iterable[str]: Only yield one path. - - Examples: - >>> with get_local_path('abc/def.jpg') as path: - ... # do something here - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - with backend.get_local_path(str(filepath)) as local_path: - yield local_path - - -def copyfile( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Union[str, Path]: - """Copy a file src to dst and return the destination file. - - src and dst should have the same prefix. If dst specifies a directory, - the file will be copied into dst using the base filename from src. If - dst specifies a file that already exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: The destination file. - - Raises: - SameFileError: If src and dst are the same file, a SameFileError will - be raised. 
- - Examples: - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> copyfile(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to '/path1/of/dir/file' - >>> copyfile(src, dst) - '/path1/of/dir/file' - """ - backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.copyfile(src, dst) - - -def copytree( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Union[str, Path]: - """Recursively copy an entire directory tree rooted at src to a directory - named dst and return the destination directory. - - src and dst should have the same prefix and dst must not already exist. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Returns: - str: The destination directory. - - Raises: - FileExistsError: If dst had already existed, a FileExistsError will be - raised. - - Examples: - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> copytree(src, dst) - '/path/of/dir2' - """ - backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.copytree(src, dst) - - -def copyfile_from_local( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Union[str, Path]: - """Copy a local file src to dst and return the destination file. - - Note: - If the backend is the instance of LocalBackend, it does the same - thing with :func:`copyfile`. - - Args: - src (str or Path): A local file to be copied. - dst (str or Path): Copy file to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = 's3://openmmlab/mmengine/file1' - >>> # src will be copied to 's3://openmmlab/mmengine/file1' - >>> copyfile_from_local(src, dst) - s3://openmmlab/mmengine/file1 - - >>> # dst is a directory - >>> dst = 's3://openmmlab/mmengine' - >>> # src will be copied to 's3://openmmlab/mmengine/file'' - >>> copyfile_from_local(src, dst) - 's3://openmmlab/mmengine/file' - """ - backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.copyfile_from_local(src, dst) - - -def copytree_from_local( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Union[str, Path]: - """Recursively copy an entire directory tree rooted at src to a directory - named dst and return the destination directory. - - Note: - If the backend is the instance of LocalBackend, it does the same - thing with :func:`copytree`. - - Args: - src (str or Path): A local directory to be copied. - dst (str or Path): Copy directory to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: The destination directory. 
- - Examples: - >>> src = '/path/of/dir' - >>> dst = 's3://openmmlab/mmengine/dir' - >>> copyfile_from_local(src, dst) - 's3://openmmlab/mmengine/dir' - """ - backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.copytree_from_local(src, dst) - - -def copyfile_to_local( - src: Union[str, Path], - dst: Union[str, Path], - dst_type: str, # Choose from ["file", "dir"] - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Union[str, Path]: - """Copy the file src to local dst and return the destination file. - - If dst specifies a directory, the file will be copied into dst using - the base filename from src. If dst specifies a file that already - exists, it will be replaced. - - Note: - If the backend is the instance of LocalBackend, it does the same - thing with :func:`copyfile`. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to to local dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> # dst is a file - >>> src = 's3://openmmlab/mmengine/file' - >>> dst = '/path/of/file' - >>> # src will be copied to '/path/of/file' - >>> copyfile_to_local(src, dst) - '/path/of/file' - - >>> # dst is a directory - >>> dst = '/path/of/dir' - >>> # src will be copied to '/path/of/dir/file' - >>> copyfile_to_local(src, dst) - '/path/of/dir/file' - """ - assert dst_type in ["file", "dir"] - Path(dst).parent.mkdir(parents=True, exist_ok=True) - backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.copyfile_to_local(src, dst, dst_type=dst_type) - - -def copytree_to_local( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Union[str, Path]: - """Recursively copy an entire directory tree rooted at src to a local - directory named dst and return the destination directory. - - Note: - If the backend is the instance of LocalBackend, it does the same - thing with :func:`copytree`. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to local dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: The destination directory. - - Examples: - >>> src = 's3://openmmlab/mmengine/dir' - >>> dst = '/path/of/dir' - >>> copytree_to_local(src, dst) - '/path/of/dir' - """ - Path(dst).parent.mkdir(parents=True, exist_ok=True) - backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.copytree_to_local(src, dst) - - -def remove( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> None: - """Remove a file. - - Args: - filepath (str, Path): Path to be removed. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Raises: - FileNotFoundError: If filepath does not exist, an FileNotFoundError - will be raised. - IsADirectoryError: If filepath is a directory, an IsADirectoryError - will be raised. 
- - Examples: - >>> filepath = '/path/of/file' - >>> remove(filepath) - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - backend.remove(filepath) - - -def rmtree( - dir_path: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> None: - """Recursively delete a directory tree. - - Args: - dir_path (str or Path): A directory to be removed. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Examples: - >>> dir_path = '/path/of/dir' - >>> rmtree(dir_path) - """ - backend = get_file_backend( - dir_path, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - backend.rmtree(dir_path) - - -def copy_if_symlink_fails( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> bool: - """Create a symbolic link pointing to src named dst. - - If failed to create a symbolic link pointing to src, directory copy src to - dst instead. - - Args: - src (str or Path): Create a symbolic link pointing to src. - dst (str or Path): Create a symbolic link named dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - bool: Return True if successfully create a symbolic link pointing to - src. Otherwise, return False. - - Examples: - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> copy_if_symlink_fails(src, dst) - True - >>> src = '/path/of/dir' - >>> dst = '/path1/of/dir1' - >>> copy_if_symlink_fails(src, dst) - True - """ - backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.copy_if_symlink_fails(src, dst) - - -def list_dir( - dir_path: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -): - """List all folders in an S3 bucket with a given prefix. - - Args: - dir_path (str | Path): Path of the directory. - - Examples: - >>> dir_path = '/path/of/dir' - >>> for file_path in list_dir(dir_path): - ... print(file_path) - """ - if not dir_path.endswith("/"): - dir_path += "/" - backend = get_file_backend( - dir_path, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - - return backend.list_dir(dir_path) - - -def list_dir_or_file( - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False, - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - - Args: - dir_path (str or Path): Path of the directory. - list_dir (bool): List the directories. Defaults to True. - list_file (bool): List the path of files. Defaults to True. - suffix (str or tuple[str], optional): File suffix that we are - interested in. Defaults to None. - recursive (bool): If set to True, recursively scan the directory. - Defaults to False. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - - Examples: - >>> dir_path = '/path/of/dir' - >>> for file_path in list_dir_or_file(dir_path): - ... 
print(file_path) - >>> # list those files and directories in current directory - >>> for file_path in list_dir_or_file(dir_path): - ... print(file_path) - >>> # only list files - >>> for file_path in list_dir_or_file(dir_path, list_dir=False): - ... print(file_path) - >>> # only list directories - >>> for file_path in list_dir_or_file(dir_path, list_file=False): - ... print(file_path) - >>> # only list files ending with specified suffixes - >>> for file_path in list_dir_or_file(dir_path, suffix='.txt'): - ... print(file_path) - >>> # list all files and directory recursively - >>> for file_path in list_dir_or_file(dir_path, recursive=True): - ... print(file_path) - """ - backend = get_file_backend( - dir_path, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - yield from backend.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) - - -def generate_presigned_url( - url: str, - client_method: str = "get_object", - expires_in: int = 3600, - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> str: - """Generate the presigned url of video stream which can be passed to - mmcv.VideoReader. Now only work on s3 backend. - - Note: - Now only work on s3 backend. - - Args: - url (str): Url of video stream. - client_method (str): Method of client, 'get_object' or - 'put_object'. Defaults to 'get_object'. - expires_in (int): expires, in seconds. Defaults to 3600. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: Generated presigned url. - """ - backend = get_file_backend(url, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.generate_presigned_url(url, client_method, expires_in) - - -def load( - file: Union[str, Path, IO[Any]], - file_format: Optional[str] = None, - file_client_args: Optional[dict] = None, - fast_backend: bool = False, - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, - **kwargs, -): - """Load data from json/yaml/pickle files. - - This method provides a unified api for loading data from serialized files. - - ``load`` supports loading data from serialized files those can be storaged - in different backends. - - Args: - file (str or :obj:`Path` or file-like object): Filename or a file-like - object. - file_format (str, optional): If not specified, the file format will be - inferred from the file extension, otherwise use the specified one. - Currently supported formats include "json", "yaml/yml" and - "pickle/pkl". - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - ``backend_args`` instead. - fast_backend: bool: Whether to use multiprocess. Defaults to False. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. - - Examples: - >>> load('/path/of/your/file') # file is storaged in disk - >>> load('https://path/of/your/file') # file is storaged in Internet - >>> load('s3://path/of/your/file') # file is storaged in s3 - - Returns: - The content from the file. 
- """ - if isinstance(file, Path): - file = str(file) - if file_format is None and isinstance(file, str): - file_format = file.split(".")[-1] - # convert file_format to lower case - file_format = file_format.lower() - if file_format not in file_handlers: - raise TypeError(f"Unsupported format: {file_format}") - - if file_client_args is not None: - warnings.warn( - '"file_client_args" will be deprecated in future. ' 'Please use "backend_args" instead', - DeprecationWarning, - ) - if backend_args is not None: - raise ValueError('"file_client_args and "backend_args" cannot be set at the ' "same time.") - - handler = file_handlers[file_format] - if isinstance(file, str): - if file_client_args is not None: - file_client = FileClient.infer_client(file_client_args, file) - file_backend = file_client - else: - file_backend = get_file_backend( - file, - backend_args=backend_args, - backend_key=backend_key, - enable_singleton=True, - ) - - if handler.str_like: - with StringIO(file_backend.get_text(file)) as f: - obj = handler.load_from_fileobj(f, **kwargs) - else: - if fast_backend: - if hasattr(file_backend, "fast_get"): - with BytesIO(file_backend.fast_get(file)) as f: - obj = handler.load_from_fileobj(f, **kwargs) - else: - warnings.warn( - f"fast_backend is not supported by the backend, type {type(file_backend)} fallback to normal get" - ) - with BytesIO(file_backend.get(file)) as f: - obj = handler.load_from_fileobj(f, **kwargs) - else: - with BytesIO(file_backend.get(file)) as f: - obj = handler.load_from_fileobj(f, **kwargs) - elif hasattr(file, "read"): - obj = handler.load_from_fileobj(file, **kwargs) - else: - raise TypeError('"file" must be a filepath str or a file-object') - return obj - - -def dump( - obj: Any, - file: Union[str, Path, IO[Any], None] = None, - file_format: Optional[str] = None, - file_client_args: Optional[dict] = None, - fast_backend: bool = False, - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, - **kwargs, -): - """Dump data to json/yaml/pickle strings or files. - - This method provides a unified api for dumping data as strings or to files, - and also supports custom arguments for each file format. - - ``dump`` supports dumping data as strings or to files which is saved to - different backends. - - Args: - obj (any): The python object to be dumped. - file (str or :obj:`Path` or file-like object, optional): If not - specified, then the object is dumped to a str, otherwise to a file - specified by the filename or file-like object. - file_format (str, optional): Same as :func:`load`. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - ``backend_args`` instead. - fast_backend: bool: Whether to use multiprocess. Defaults to False. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. - backend_key: str: The key to register the backend. Defaults to None. - - Examples: - >>> dump('hello world', '/path/of/your/file') # disk - >>> dump('hello world', 's3://path/of/your/file') # ceph or s3 - - Returns: - bool: True for success, False otherwise. 
- """ - if isinstance(file, Path): - file = str(file) - if file_format is None: - if isinstance(file, str): - file_format = file.split(".")[-1] - elif file is None: - raise ValueError("file_format must be specified since file is None") - # convert file_format to lower case - file_format = file_format.lower() - if file_format not in file_handlers: - raise TypeError(f"Unsupported format: {file_format}") - - if file_client_args is not None: - warnings.warn( - '"file_client_args" will be deprecated in future. ' 'Please use "backend_args" instead', - DeprecationWarning, - ) - if backend_args is not None: - raise ValueError('"file_client_args" and "backend_args" cannot be set at the ' "same time.") - - handler = file_handlers[file_format] - if file is None: - return handler.dump_to_str(obj, **kwargs) - elif isinstance(file, str): - if file_client_args is not None: - file_client = FileClient.infer_client(file_client_args, file) - file_backend = file_client - else: - file_backend = get_file_backend( - file, - backend_args=backend_args, - backend_key=backend_key, - enable_singleton=True, - ) - - if handler.str_like: - with StringIO() as f: - handler.dump_to_fileobj(obj, f, **kwargs) - file_backend.put_text(f.getvalue(), file) - else: - with BytesIO() as f: - handler.dump_to_fileobj(obj, f, **kwargs) - if fast_backend: - if hasattr(file_backend, "fast_put"): - file_backend.fast_put(f, file) - else: - warnings.warn("fast_backend is not supported by the backend, fallback to normal put") - file_backend.put(f, file) - else: - file_backend.put(f, file) - elif hasattr(file, "write"): - handler.dump_to_fileobj(obj, file, **kwargs) - else: - raise TypeError('"file" must be a filename str or a file-object') diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/file_client.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/file_client.py deleted file mode 100644 index be8a378ecdefd6129ada0194bc7da1fd66a701e4..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/file_client.py +++ /dev/null @@ -1,450 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Generator, Iterator, Optional, Tuple, Union - -from cosmos_transfer1.utils.easy_io.backends import BaseStorageBackend, HTTPBackend, LocalBackend - - -def is_filepath(filepath): - return isinstance(filepath, (str, Path)) - - -class HardDiskBackend(LocalBackend): - """Raw hard disks storage backend.""" - - @property - def name(self): - return self.__class__.__name__ - - -class FileClient: - """A general file client to access files in different backends. - - The client loads a file or text in a specified backend from its path - and returns it as a binary or text file. There are two ways to choose a - backend, the name of backend and the prefix of path. 
Although both of them - can be used to choose a storage backend, ``backend`` has a higher priority - that is if they are all set, the storage backend will be chosen by the - backend argument. If they are all `None`, the disk backend will be chosen. - Note that It can also register other backend accessor with a given name, - prefixes, and backend class. In addition, We use the singleton pattern to - avoid repeated object creation. If the arguments are the same, the same - object will be returned. - - Warning: - `FileClient` will be deprecated in future. Please use io functions - in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io - - Args: - backend (str, optional): The storage backend type. Options are "disk", - "memcached", "lmdb", "http" and "s3". Defaults to None. - prefix (str, optional): The prefix of the registered storage backend. - Options are "s3", "http", "https". Defaults to None. - - Examples: - >>> # only set backend - >>> file_client = FileClient(backend='s3') - >>> # only set prefix - >>> file_client = FileClient(prefix='s3') - >>> # set both backend and prefix but use backend to choose client - >>> file_client = FileClient(backend='s3', prefix='s3') - >>> # if the arguments are the same, the same object is returned - >>> file_client1 = FileClient(backend='s3') - >>> file_client1 is file_client - True - - Attributes: - client (:obj:`BaseStorageBackend`): The backend object. - """ - - _backends = { - "disk": HardDiskBackend, - "http": HTTPBackend, - } - - _prefix_to_backends: dict = { - "http": HTTPBackend, - "https": HTTPBackend, - } - - _instances: dict = {} - - client: Any - - def __new__(cls, backend=None, prefix=None, **kwargs): - if backend is None and prefix is None: - backend = "disk" - if backend is not None and backend not in cls._backends: - raise ValueError( - f"Backend {backend} is not supported. Currently supported ones" f" are {list(cls._backends.keys())}" - ) - if prefix is not None and prefix not in cls._prefix_to_backends: - raise ValueError( - f"prefix {prefix} is not supported. Currently supported ones " - f"are {list(cls._prefix_to_backends.keys())}" - ) - - # concatenate the arguments to a unique key for determining whether - # objects with the same arguments were created - arg_key = f"{backend}:{prefix}" - for key, value in kwargs.items(): - arg_key += f":{key}:{value}" - - # if a backend was overridden, it will create a new object - if arg_key in cls._instances: - _instance = cls._instances[arg_key] - else: - # create a new object and put it to _instance - _instance = super().__new__(cls) - if backend is not None: - _instance.client = cls._backends[backend](**kwargs) - else: - _instance.client = cls._prefix_to_backends[prefix](**kwargs) - - cls._instances[arg_key] = _instance - - return _instance - - @property - def name(self): - return self.client.name - - @property - def allow_symlink(self): - return self.client.allow_symlink - - @staticmethod - def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: - """Parse the prefix of a uri. - - Args: - uri (str | Path): Uri to be parsed that contains the file prefix. - - Examples: - >>> FileClient.parse_uri_prefix('s3://path/of/your/file') - 's3' - - Returns: - str | None: Return the prefix of uri if the uri contains '://' else - ``None``. 
- """ - assert is_filepath(uri) - uri = str(uri) - if "://" not in uri: - return None - else: - prefix, _ = uri.split("://") - if ":" in prefix: - _, prefix = prefix.split(":") - return prefix - - @classmethod - def infer_client( - cls, - file_client_args: Optional[dict] = None, - uri: Optional[Union[str, Path]] = None, - ) -> "FileClient": - """Infer a suitable file client based on the URI and arguments. - - Args: - file_client_args (dict, optional): Arguments to instantiate a - FileClient. Defaults to None. - uri (str | Path, optional): Uri to be parsed that contains the file - prefix. Defaults to None. - - Examples: - >>> uri = 's3://path/of/your/file' - >>> file_client = FileClient.infer_client(uri=uri) - >>> file_client_args = {'backend': 's3'} - >>> file_client = FileClient.infer_client(file_client_args) - - Returns: - FileClient: Instantiated FileClient object. - """ - assert file_client_args is not None or uri is not None - if file_client_args is None: - file_prefix = cls.parse_uri_prefix(uri) # type: ignore - return cls(prefix=file_prefix) - else: - return cls(**file_client_args) - - @classmethod - def _register_backend(cls, name, backend, force=False, prefixes=None): - if not isinstance(name, str): - raise TypeError("the backend name should be a string, " f"but got {type(name)}") - if not inspect.isclass(backend): - raise TypeError(f"backend should be a class but got {type(backend)}") - if not issubclass(backend, BaseStorageBackend): - raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend") - if not force and name in cls._backends: - raise KeyError( - f"{name} is already registered as a storage backend, " 'add "force=True" if you want to override it' - ) - - if name in cls._backends and force: - for arg_key, instance in list(cls._instances.items()): - if isinstance(instance.client, cls._backends[name]): - cls._instances.pop(arg_key) - cls._backends[name] = backend - - if prefixes is not None: - if isinstance(prefixes, str): - prefixes = [prefixes] - else: - assert isinstance(prefixes, (list, tuple)) - for prefix in prefixes: - if prefix not in cls._prefix_to_backends: - cls._prefix_to_backends[prefix] = backend - elif (prefix in cls._prefix_to_backends) and force: - overridden_backend = cls._prefix_to_backends[prefix] - for arg_key, instance in list(cls._instances.items()): - if isinstance(instance.client, overridden_backend): - cls._instances.pop(arg_key) - else: - raise KeyError( - f"{prefix} is already registered as a storage backend," - ' add "force=True" if you want to override it' - ) - - @classmethod - def register_backend(cls, name, backend=None, force=False, prefixes=None): - """Register a backend to FileClient. - - This method can be used as a normal class method or a decorator. - - .. code-block:: python - - class NewBackend(BaseStorageBackend): - - def get(self, filepath): - return filepath - - def get_text(self, filepath): - return filepath - - FileClient.register_backend('new', NewBackend) - - or - - .. code-block:: python - - @FileClient.register_backend('new') - class NewBackend(BaseStorageBackend): - - def get(self, filepath): - return filepath - - def get_text(self, filepath): - return filepath - - Args: - name (str): The name of the registered backend. - backend (class, optional): The backend class to be registered, - which must be a subclass of :class:`BaseStorageBackend`. - When this method is used as a decorator, backend is None. - Defaults to None. 
- force (bool, optional): Whether to override the backend if the name - has already been registered. Defaults to False. - prefixes (str or list[str] or tuple[str], optional): The prefixes - of the registered storage backend. Defaults to None. - `New in version 1.3.15.` - """ - if backend is not None: - cls._register_backend(name, backend, force=force, prefixes=prefixes) - return - - def _register(backend_cls): - cls._register_backend(name, backend_cls, force=force, prefixes=prefixes) - return backend_cls - - return _register - - def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: - """Read data from a given ``filepath`` with 'rb' mode. - - Note: - There are two types of return values for ``get``, one is ``bytes`` - and the other is ``memoryview``. The advantage of using memoryview - is that you can avoid copying, and if you want to convert it to - ``bytes``, you can use ``.tobytes()``. - - Args: - filepath (str or Path): Path to read data. - - Returns: - bytes | memoryview: Expected bytes object or a memory view of the - bytes object. - """ - return self.client.get(filepath) - - def get_text(self, filepath: Union[str, Path], encoding="utf-8") -> str: - """Read data from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - """ - return self.client.get_text(filepath, encoding) - - def put(self, obj: bytes, filepath: Union[str, Path]) -> None: - """Write data to a given ``filepath`` with 'wb' mode. - - Note: - ``put`` should create a directory if the directory of ``filepath`` - does not exist. - - Args: - obj (bytes): Data to be written. - filepath (str or Path): Path to write data. - """ - self.client.put(obj, filepath) - - def put_text(self, obj: str, filepath: Union[str, Path]) -> None: - """Write data to a given ``filepath`` with 'w' mode. - - Note: - ``put_text`` should create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str, optional): The encoding format used to open the - `filepath`. Defaults to 'utf-8'. - """ - self.client.put_text(obj, filepath) - - def remove(self, filepath: Union[str, Path]) -> None: - """Remove a file. - - Args: - filepath (str, Path): Path to be removed. - """ - self.client.remove(filepath) - - def exists(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - """ - return self.client.exists(filepath) - - def isdir(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - """ - return self.client.isdir(filepath) - - def isfile(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - """ - return self.client.isfile(filepath) - - def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: - r"""Concatenate all file paths. 
- - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of \*filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - - Returns: - str: The result of concatenation. - """ - return self.client.join_path(filepath, *filepaths) - - @contextmanager - def get_local_path(self, filepath: Union[str, Path]) -> Generator[Union[str, Path], None, None]: - """Download data from ``filepath`` and write the data to local path. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Note: - If the ``filepath`` is a local path, just return itself. - - .. warning:: - ``get_local_path`` is an experimental interface that may change in - the future. - - Args: - filepath (str or Path): Path to be read data. - - Examples: - >>> file_client = FileClient(prefix='s3') - >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path: - ... # do something here - - Yields: - Iterable[str]: Only yield one path. - """ - with self.client.get_local_path(str(filepath)) as local_path: - yield local_path - - def list_dir_or_file( # pylint: disable=too-many-arguments - self, - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False, - ) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - - Args: - dir_path (str | Path): Path of the directory. - list_dir (bool): List the directories. Defaults to True. - list_file (bool): List the path of files. Defaults to True. - suffix (str or tuple[str], optional): File suffix - that we are interested in. Defaults to None. - recursive (bool): If set to True, recursively scan the - directory. Defaults to False. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - """ - yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/__init__.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/__init__.py deleted file mode 100644 index aafac064bda0c10325e6745cf2edd926c1579ddb..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
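A minimal sketch of how the ``FileClient.register_backend`` API shown above is meant to be used, assuming ``BaseStorageBackend`` requires only ``get`` and ``get_text`` to be implemented; ``RamBackend`` and the ``ram`` prefix are hypothetical names used purely for illustration:

.. code-block:: python

    # Sketch only: register a custom backend and URI prefix with FileClient.
    # "RamBackend" and the "ram" prefix are hypothetical; BaseStorageBackend is
    # assumed to require only get() and get_text().
    from cosmos_transfer1.utils.easy_io.backends import BaseStorageBackend
    from cosmos_transfer1.utils.easy_io.file_client import FileClient


    @FileClient.register_backend("ram", prefixes="ram")
    class RamBackend(BaseStorageBackend):
        """Keeps objects in a process-local dict instead of on disk."""

        _store: dict = {}

        def get(self, filepath):
            return self._store[str(filepath)]

        def get_text(self, filepath, encoding="utf-8"):
            return self._store[str(filepath)].decode(encoding)


    RamBackend._store["ram://demo.txt"] = b"hello"
    client = FileClient(prefix="ram")   # resolved through the registered prefix
    assert client.get_text("ram://demo.txt") == "hello"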
- -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler -from cosmos_transfer1.utils.easy_io.handlers.json_handler import JsonHandler -from cosmos_transfer1.utils.easy_io.handlers.pickle_handler import PickleHandler -from cosmos_transfer1.utils.easy_io.handlers.registry_utils import file_handlers, register_handler -from cosmos_transfer1.utils.easy_io.handlers.yaml_handler import YamlHandler - -__all__ = [ - "BaseFileHandler", - "JsonHandler", - "PickleHandler", - "YamlHandler", - "register_handler", - "file_handlers", -] diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/base.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/base.py deleted file mode 100644 index 5e5dcbcabc40807706eeb43d1a598571c51922a8..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/base.py +++ /dev/null @@ -1,44 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABCMeta, abstractmethod - - -class BaseFileHandler(metaclass=ABCMeta): - # `str_like` is a flag to indicate whether the type of file object is - # str-like object or bytes-like object. Pickle only processes bytes-like - # objects but json only processes str-like object. If it is str-like - # object, `StringIO` will be used to process the buffer. - str_like = True - - @abstractmethod - def load_from_fileobj(self, file, **kwargs): - pass - - @abstractmethod - def dump_to_fileobj(self, obj, file, **kwargs): - pass - - @abstractmethod - def dump_to_str(self, obj, **kwargs): - pass - - def load_from_path(self, filepath, mode="r", **kwargs): - with open(filepath, mode) as f: - return self.load_from_fileobj(f, **kwargs) - - def dump_to_path(self, obj, filepath, mode="w", **kwargs): - with open(filepath, mode) as f: - self.dump_to_fileobj(obj, f, **kwargs) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/csv_handler.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/csv_handler.py deleted file mode 100644 index c76294ca235000e522d8f8fdce1913faeab8f1f8..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/csv_handler.py +++ /dev/null @@ -1,42 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import csv -from io import StringIO - -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler - - -class CsvHandler(BaseFileHandler): - def load_from_fileobj(self, file, **kwargs): - del kwargs - reader = csv.reader(file) - return list(reader) - - def dump_to_fileobj(self, obj, file, **kwargs): - del kwargs - writer = csv.writer(file) - if not all(isinstance(row, list) for row in obj): - raise ValueError("Each row must be a list") - writer.writerows(obj) - - def dump_to_str(self, obj, **kwargs): - del kwargs - output = StringIO() - writer = csv.writer(output) - if not all(isinstance(row, list) for row in obj): - raise ValueError("Each row must be a list") - writer.writerows(obj) - return output.getvalue() diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/gzip_handler.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/gzip_handler.py deleted file mode 100644 index 2e063a73dbe79c91ebe8ce037c855cb20ae2110b..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/gzip_handler.py +++ /dev/null @@ -1,33 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import gzip -import pickle -from io import BytesIO -from typing import Any - -from cosmos_transfer1.utils.easy_io.handlers.pickle_handler import PickleHandler - - -class GzipHandler(PickleHandler): - str_like = False - - def load_from_fileobj(self, file: BytesIO, **kwargs): - with gzip.GzipFile(fileobj=file, mode="rb") as f: - return pickle.load(f) - - def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs): - with gzip.GzipFile(fileobj=file, mode="wb") as f: - pickle.dump(obj, f) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/imageio_video_handler.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/imageio_video_handler.py deleted file mode 100644 index 67dbbc27616191b7940dbe6ab89244562de7bf37..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/imageio_video_handler.py +++ /dev/null @@ -1,91 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
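A minimal sketch of the ``str_like`` contract that distinguishes the handlers above: ``CsvHandler`` round-trips through ``StringIO``, while ``GzipHandler`` (pickle over gzip) round-trips through ``BytesIO``:

.. code-block:: python

    # Sketch only: str-like handlers work over text buffers, bytes-like
    # handlers over byte buffers.
    from io import BytesIO, StringIO

    from cosmos_transfer1.utils.easy_io.handlers.csv_handler import CsvHandler
    from cosmos_transfer1.utils.easy_io.handlers.gzip_handler import GzipHandler

    rows = [["name", "score"], ["a", "1"]]
    text = CsvHandler().dump_to_str(rows)                 # plain CSV text
    assert CsvHandler().load_from_fileobj(StringIO(text)) == rows

    buf = BytesIO()
    GzipHandler().dump_to_fileobj({"step": 1}, buf)       # gzip-compressed pickle
    buf.seek(0)
    assert GzipHandler().load_from_fileobj(buf) == {"step": 1}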
- -from typing import IO - -import numpy as np -import torch - -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler - -try: - import imageio -except ImportError: - imageio = None - - -class ImageioVideoHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file: IO[bytes], format: str = "mp4", mode: str = "rgb", **kwargs): - """ - Load video from a file-like object using imageio with specified format and color mode. - - Parameters: - file (IO[bytes]): A file-like object containing video data. - format (str): Format of the video file (default 'mp4'). - mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb'). - - Returns: - tuple: A tuple containing an array of video frames and metadata about the video. - """ - file.seek(0) - video_reader = imageio.get_reader(file, format, **kwargs) - - video_frames = [] - for frame in video_reader: - if mode == "gray": - import cv2 # Convert frame to grayscale if mode is gray - - frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) - frame = np.expand_dims(frame, axis=2) # Keep frame dimensions consistent - video_frames.append(frame) - - return np.array(video_frames), video_reader.get_meta_data() - - def dump_to_fileobj( - self, - obj: np.ndarray | torch.Tensor, - file: IO[bytes], - format: str = "mp4", # pylint: disable=redefined-builtin - fps: int = 17, - quality: int = 5, - **kwargs, - ): - """ - Save an array of video frames to a file-like object using imageio. - - Parameters: - obj (np.ndarray): An array of frames to be saved as video. - file (IO[bytes]): A file-like object to which the video data will be written. - format (str): Format of the video file (default 'mp4'). - fps (int): Frames per second of the output video (default 30). - - """ - if isinstance(obj, torch.Tensor): - assert obj.dtype == torch.uint8 - obj = obj.cpu().numpy() - h, w = obj.shape[1:-1] - kwargs = { - "fps": fps, - "quality": quality, - "macro_block_size": 1, - "ffmpeg_params": ["-s", f"{w}x{h}"], - "output_params": ["-f", "mp4"], - } - imageio.mimsave(file, obj, format, **kwargs) - - def dump_to_str(self, obj, **kwargs): - raise NotImplementedError diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/json_handler.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/json_handler.py deleted file mode 100644 index beb55c61c93549b856d5825f244ba7a782e2acf1..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/json_handler.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -import numpy as np - -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler - - -def set_default(obj): - """Set default json values for non-serializable values. - - It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. 
- It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, - etc.) into plain numbers of plain python built-in types. - """ - if isinstance(obj, (set, range)): - return list(obj) - elif isinstance(obj, np.ndarray): - return obj.tolist() - elif isinstance(obj, np.generic): - return obj.item() - raise TypeError(f"{type(obj)} is unsupported for json dump") - - -class JsonHandler(BaseFileHandler): - def load_from_fileobj(self, file): - return json.load(file) - - def dump_to_fileobj(self, obj, file, **kwargs): - kwargs.setdefault("default", set_default) - json.dump(obj, file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - kwargs.setdefault("default", set_default) - return json.dumps(obj, **kwargs) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/jsonl_handler.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/jsonl_handler.py deleted file mode 100644 index 000ffb1966d1cba52d7ca444a2657463b2a9a2cc..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/jsonl_handler.py +++ /dev/null @@ -1,80 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from typing import IO - -import numpy as np - -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler - - -def set_default(obj): - """Set default json values for non-serializable values. - - It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. - It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, - etc.) into plain numbers of plain python built-in types. - """ - if isinstance(obj, (set, range)): - return list(obj) - elif isinstance(obj, np.ndarray): - return obj.tolist() - elif isinstance(obj, np.generic): - return obj.item() - raise TypeError(f"{type(obj)} is unsupported for json dump") - - -class JsonlHandler(BaseFileHandler): - """Handler for JSON lines (JSONL) files.""" - - def load_from_fileobj(self, file: IO[bytes]): - """Load JSON objects from a newline-delimited JSON (JSONL) file object. - - Returns: - A list of Python objects loaded from each JSON line. - """ - data = [] - for line in file: - line = line.strip() - if not line: - continue # skip empty lines if any - data.append(json.loads(line)) - return data - - def dump_to_fileobj(self, obj: IO[bytes], file, **kwargs): - """Dump a list of objects to a newline-delimited JSON (JSONL) file object. - - Args: - obj: A list (or iterable) of objects to dump line by line. 
- """ - kwargs.setdefault("default", set_default) - for item in obj: - file.write(json.dumps(item, **kwargs) + "\n") - - def dump_to_str(self, obj, **kwargs): - """Dump a list of objects to a newline-delimited JSON (JSONL) string.""" - kwargs.setdefault("default", set_default) - lines = [json.dumps(item, **kwargs) for item in obj] - return "\n".join(lines) - - -if __name__ == "__main__": - from cosmos_transfer1.utils.easy_io import easy_io - - easy_io.dump([1, 2, 3], "test.jsonl", file_format="jsonl") - print(easy_io.load("test.jsonl")) - easy_io.dump([{"key1": 1, "key2": 2}, {"key1": 3, "key2": 4}], "test.jsonl", file_format="jsonl") - print(easy_io.load("test.jsonl")) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/np_handler.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/np_handler.py deleted file mode 100644 index 070396a1310ad892866ed4540a24738e2d8a0e2f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/np_handler.py +++ /dev/null @@ -1,89 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from io import BytesIO -from typing import IO, Any - -import numpy as np - -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler - - -class NumpyHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file: IO[bytes], **kwargs) -> Any: - """ - Load a NumPy array from a file-like object. - - Parameters: - file (IO[bytes]): The file-like object containing the NumPy array data. - **kwargs: Additional keyword arguments passed to `np.load`. - - Returns: - numpy.ndarray: The loaded NumPy array. - """ - return np.load(file, **kwargs) - - def load_from_path(self, filepath: str, **kwargs) -> Any: - """ - Load a NumPy array from a file path. - - Parameters: - filepath (str): The path to the file to load. - **kwargs: Additional keyword arguments passed to `np.load`. - - Returns: - numpy.ndarray: The loaded NumPy array. - """ - return super().load_from_path(filepath, mode="rb", **kwargs) - - def dump_to_str(self, obj: np.ndarray, **kwargs) -> str: - """ - Serialize a NumPy array to a string in binary format. - - Parameters: - obj (np.ndarray): The NumPy array to serialize. - **kwargs: Additional keyword arguments passed to `np.save`. - - Returns: - str: The serialized NumPy array as a string. - """ - with BytesIO() as f: - np.save(f, obj, **kwargs) - return f.getvalue() - - def dump_to_fileobj(self, obj: np.ndarray, file: IO[bytes], **kwargs): - """ - Dump a NumPy array to a file-like object. - - Parameters: - obj (np.ndarray): The NumPy array to dump. - file (IO[bytes]): The file-like object to which the array is dumped. - **kwargs: Additional keyword arguments passed to `np.save`. - """ - np.save(file, obj, **kwargs) - - def dump_to_path(self, obj: np.ndarray, filepath: str, **kwargs): - """ - Dump a NumPy array to a file path. 
- - Parameters: - obj (np.ndarray): The NumPy array to dump. - filepath (str): The file path where the array should be saved. - **kwargs: Additional keyword arguments passed to `np.save`. - """ - with open(filepath, "wb") as f: - np.save(f, obj, **kwargs) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/pandas_handler.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/pandas_handler.py deleted file mode 100644 index 3389cfc88913ecc45e635b138420c74a5928e1e7..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/pandas_handler.py +++ /dev/null @@ -1,31 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pandas as pd - -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler # isort:skip - - -class PandasHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file, **kwargs): - return pd.read_csv(file, **kwargs) - - def dump_to_fileobj(self, obj, file, **kwargs): - obj.to_csv(file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - raise NotImplementedError("PandasHandler does not support dumping to str") diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/pickle_handler.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/pickle_handler.py deleted file mode 100644 index 618e750aac7c299f64a8fba5e116a73116ad614f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/pickle_handler.py +++ /dev/null @@ -1,42 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
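A minimal sketch of a ``NumpyHandler`` round trip through an in-memory buffer, mirroring how ``load``/``dump`` drive it for ``.npy`` paths:

.. code-block:: python

    # Sketch only: NumpyHandler operates purely on byte streams.
    from io import BytesIO

    import numpy as np

    from cosmos_transfer1.utils.easy_io.handlers.np_handler import NumpyHandler

    handler = NumpyHandler()
    arr = np.arange(6, dtype=np.float32).reshape(2, 3)

    buf = BytesIO()
    handler.dump_to_fileobj(arr, buf)                # np.save under the hood
    buf.seek(0)
    restored = handler.load_from_fileobj(buf)        # np.load under the hood
    assert np.array_equal(arr, restored)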
- -import pickle -from io import BytesIO -from typing import Any - -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler - - -class PickleHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file: BytesIO, **kwargs): - return pickle.load(file, **kwargs) - - def load_from_path(self, filepath, **kwargs): - return super().load_from_path(filepath, mode="rb", **kwargs) - - def dump_to_str(self, obj, **kwargs): - kwargs.setdefault("protocol", 2) - return pickle.dumps(obj, **kwargs) - - def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs): - kwargs.setdefault("protocol", 2) - pickle.dump(obj, file, **kwargs) - - def dump_to_path(self, obj, filepath, **kwargs): - with open(filepath, "wb") as f: - pickle.dump(obj, f, **kwargs) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/pil_handler.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/pil_handler.py deleted file mode 100644 index 618ca9d22bb42d8322cc8578b4a932e1f2a99246..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/pil_handler.py +++ /dev/null @@ -1,96 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import IO, Optional, Tuple, Union - -import numpy as np - -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler - -try: - from PIL import Image -except ImportError: - Image = None - - -class PILHandler(BaseFileHandler): - format: str - str_like = False - - def load_from_fileobj( - self, - file: IO[bytes], - fmt: str = "pil", - size: Optional[Union[int, Tuple[int, int]]] = None, - **kwargs, - ): - """ - Load an image from a file-like object and return it in a specified format. - - Args: - file (IO[bytes]): A file-like object containing the image data. - fmt (str): The format to convert the image into. Options are \ - 'numpy', 'np', 'npy', 'type' (all return numpy arrays), \ - 'pil' (returns PIL Image), 'th', 'torch' (returns a torch tensor). - size (Optional[Union[int, Tuple[int, int]]]): The new size of the image as a single integer \ - or a tuple of (width, height). If specified, the image is resized accordingly. - **kwargs: Additional keyword arguments that can be passed to conversion functions. - - Returns: - Image data in the format specified by `fmt`. - - Raises: - IOError: If the image cannot be loaded or processed. - ValueError: If the specified format is unsupported. 
- """ - try: - img = Image.open(file) - img.load() # Explicitly load the image data - if size is not None: - if isinstance(size, int): - size = ( - size, - size, - ) # create a tuple if only one integer is provided - img = img.resize(size, Image.ANTIALIAS) - - # Return the image in the requested format - if fmt in ["numpy", "np", "npy"]: - return np.array(img, **kwargs) - if fmt == "pil": - return img - if fmt in ["th", "torch"]: - import torch - - # Convert to tensor - img_tensor = torch.from_numpy(np.array(img, **kwargs)) - # Convert image from HxWxC to CxHxW - if img_tensor.ndim == 3: - img_tensor = img_tensor.permute(2, 0, 1) - return img_tensor - raise ValueError( - "Unsupported format. Supported formats are 'numpy', 'np', 'npy', 'pil', 'th', and 'torch'." - ) - except Exception as e: - raise IOError(f"Unable to load image: {e}") from e - - def dump_to_fileobj(self, obj, file: IO[bytes], **kwargs): - if "format" not in kwargs: - kwargs["format"] = self.format - kwargs["format"] = "JPEG" if self.format.lower() == "jpg" else self.format.upper() - obj.save(file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - raise NotImplementedError diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/registry_utils.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/registry_utils.py deleted file mode 100644 index 286d0a6e964964d53000d697be95f864fc46b642..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/registry_utils.py +++ /dev/null @@ -1,80 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
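A minimal sketch of ``PILHandler``, assuming Pillow and torch are available; ``format`` is set by hand here, the same assignment the handler registry performs for each registered image extension:

.. code-block:: python

    # Sketch only: PILHandler decodes an encoded image stream into a PIL
    # image, a NumPy array, or a CHW torch tensor depending on `fmt`.
    from io import BytesIO

    from PIL import Image

    from cosmos_transfer1.utils.easy_io.handlers.pil_handler import PILHandler

    handler = PILHandler()
    handler.format = "png"

    buf = BytesIO()
    handler.dump_to_fileobj(Image.new("RGB", (8, 4), color=(255, 0, 0)), buf)

    buf.seek(0)
    as_array = handler.load_from_fileobj(buf, fmt="np")      # HxWxC uint8 array
    assert as_array.shape == (4, 8, 3)

    buf.seek(0)
    as_tensor = handler.load_from_fileobj(buf, fmt="torch")  # CxHxW uint8 tensor
    assert tuple(as_tensor.shape) == (3, 4, 8)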
- -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler -from cosmos_transfer1.utils.easy_io.handlers.csv_handler import CsvHandler -from cosmos_transfer1.utils.easy_io.handlers.gzip_handler import GzipHandler -from cosmos_transfer1.utils.easy_io.handlers.imageio_video_handler import ImageioVideoHandler -from cosmos_transfer1.utils.easy_io.handlers.json_handler import JsonHandler -from cosmos_transfer1.utils.easy_io.handlers.jsonl_handler import JsonlHandler -from cosmos_transfer1.utils.easy_io.handlers.np_handler import NumpyHandler -from cosmos_transfer1.utils.easy_io.handlers.pandas_handler import PandasHandler -from cosmos_transfer1.utils.easy_io.handlers.pickle_handler import PickleHandler -from cosmos_transfer1.utils.easy_io.handlers.pil_handler import PILHandler -from cosmos_transfer1.utils.easy_io.handlers.tarfile_handler import TarHandler -from cosmos_transfer1.utils.easy_io.handlers.torch_handler import TorchHandler -from cosmos_transfer1.utils.easy_io.handlers.torchjit_handler import TorchJitHandler -from cosmos_transfer1.utils.easy_io.handlers.txt_handler import TxtHandler -from cosmos_transfer1.utils.easy_io.handlers.yaml_handler import YamlHandler - -file_handlers = { - "json": JsonHandler(), - "yaml": YamlHandler(), - "yml": YamlHandler(), - "pickle": PickleHandler(), - "pkl": PickleHandler(), - "tar": TarHandler(), - "jit": TorchJitHandler(), - "npy": NumpyHandler(), - "txt": TxtHandler(), - "csv": CsvHandler(), - "pandas": PandasHandler(), - "gz": GzipHandler(), - "jsonl": JsonlHandler(), -} - -for torch_type in ["pt", "pth", "ckpt"]: - file_handlers[torch_type] = TorchHandler() -for img_type in ["jpg", "jpeg", "png", "bmp", "gif"]: - file_handlers[img_type] = PILHandler() - file_handlers[img_type].format = img_type -for video_type in ["mp4", "avi", "mov", "webm", "flv", "wmv"]: - file_handlers[video_type] = ImageioVideoHandler() - - -def _register_handler(handler, file_formats): - """Register a handler for some file extensions. - - Args: - handler (:obj:`BaseFileHandler`): Handler to be registered. - file_formats (str or list[str]): File formats to be handled by this - handler. - """ - if not isinstance(handler, BaseFileHandler): - raise TypeError(f"handler must be a child of BaseFileHandler, not {type(handler)}") - if isinstance(file_formats, str): - file_formats = [file_formats] - if not all([isinstance(item, str) for item in file_formats]): - raise TypeError("file_formats must be a str or a list of str") - for ext in file_formats: - file_handlers[ext] = handler - - -def register_handler(file_formats, **kwargs): - def wrap(cls): - _register_handler(cls(**kwargs), file_formats) - return cls - - return wrap diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/tarfile_handler.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/tarfile_handler.py deleted file mode 100644 index 9992569deac77ce4a82ffe278c3b4d428b0f26cc..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/tarfile_handler.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tarfile - -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler - - -class TarHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file, mode="r|*", **kwargs): - return tarfile.open(fileobj=file, mode=mode, **kwargs) - - def load_from_path(self, filepath, mode="r|*", **kwargs): - return tarfile.open(filepath, mode=mode, **kwargs) - - def dump_to_fileobj(self, obj, file, mode="w", **kwargs): - with tarfile.open(fileobj=file, mode=mode) as tar: - tar.add(obj, **kwargs) - - def dump_to_path(self, obj, filepath, mode="w", **kwargs): - with tarfile.open(filepath, mode=mode) as tar: - tar.add(obj, **kwargs) - - def dump_to_str(self, obj, **kwargs): - raise NotImplementedError diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/torch_handler.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/torch_handler.py deleted file mode 100644 index 71adc6e199d566cc82d59745c3936a3b56b56a22..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/torch_handler.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -try: - import torch -except ImportError: - torch = None - -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler - - -class TorchHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file, **kwargs): - return torch.load(file, **kwargs) - - def dump_to_fileobj(self, obj, file, **kwargs): - torch.save(obj, file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - raise NotImplementedError diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/torchjit_handler.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/torchjit_handler.py deleted file mode 100644 index 6711cddfea635fc9a13d1a240c141ae82fdf64d2..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/torchjit_handler.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -try: - import torch -except ImportError: - torch = None - -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler - - -class TorchJitHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file, **kwargs): - return torch.jit.load(file, **kwargs) - - def dump_to_fileobj(self, obj, file, **kwargs): - torch.jit.save(obj, file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - raise NotImplementedError diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/txt_handler.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/txt_handler.py deleted file mode 100644 index d42408d7432d061b4bd07f3a15833d0020bdbc0c..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/txt_handler.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler - - -class TxtHandler(BaseFileHandler): - def load_from_fileobj(self, file, **kwargs): - del kwargs - return file.read() - - def dump_to_fileobj(self, obj, file, **kwargs): - del kwargs - if not isinstance(obj, str): - obj = str(obj) - file.write(obj) - - def dump_to_str(self, obj, **kwargs): - del kwargs - if not isinstance(obj, str): - obj = str(obj) - return obj diff --git a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/yaml_handler.py b/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/yaml_handler.py deleted file mode 100644 index 246c123a549357054cdee24bc34f27fefb2bfb5c..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/easy_io/handlers/yaml_handler.py +++ /dev/null @@ -1,38 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
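(Editor's note: the TorchJitHandler removed above is the same pattern for TorchScript modules, delegating to torch.jit.save / torch.jit.load on file objects. A hedged sketch, again assuming PyTorch is installed.)

# Sketch: TorchScript round trip on a file object, mirroring the deleted TorchJitHandler.
import io

import torch

scripted = torch.jit.script(torch.nn.Linear(4, 2))
buf = io.BytesIO()
torch.jit.save(scripted, buf)                       # dump_to_fileobj equivalent
buf.seek(0)
restored = torch.jit.load(buf, map_location="cpu")  # load_from_fileobj equivalent
print(restored(torch.zeros(1, 4)).shape)            # torch.Size([1, 2])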
- -import yaml - -try: - from yaml import CDumper as Dumper # type: ignore - from yaml import CLoader as Loader # type: ignore -except ImportError: - from yaml import Loader, Dumper # type: ignore - -from cosmos_transfer1.utils.easy_io.handlers.base import BaseFileHandler # isort:skip - - -class YamlHandler(BaseFileHandler): - def load_from_fileobj(self, file, **kwargs): - kwargs.setdefault("Loader", Loader) - return yaml.load(file, **kwargs) - - def dump_to_fileobj(self, obj, file, **kwargs): - kwargs.setdefault("Dumper", Dumper) - yaml.dump(obj, file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - kwargs.setdefault("Dumper", Dumper) - return yaml.dump(obj, **kwargs) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/ema.py b/cosmos-transfer1/cosmos_transfer1/utils/ema.py deleted file mode 100644 index e402f65f26290a1f991d03b0e107fe955ca312ea..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/ema.py +++ /dev/null @@ -1,327 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union - -import numpy as np -import torch -from megatron.core import parallel_state - -from cosmos_transfer1.utils import distributed, log - -if TYPE_CHECKING: - from cosmos_transfer1.utils.model import Model - - -class FastEmaModelUpdater: - """ - This class is used to update target model~(EMA) given source model~(regular model) and beta. - The method interaface mimic :class:`EMAModelTracker` and :class:`PowerEMATracker`. - Different from two classes, this class does not maintain the EMA model weights as buffers. It expects the user to have two module with same architecture and weights shape. - The class is proposed to work with FSDP model where above two classes are not working as expected. Besides, it is strange to claim model weights as buffers and do unnecessary name changing in :class:`EMAModelTracker` and :class:`PowerEMATracker`. Moeving forward, we should use this class instead of above two classes. - """ - - def __init__(self): - # Flag to indicate whether the cache is taken or not. Useful to avoid cache overwrite - self.is_cached = False - - def update_average(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module, beta: float = 0.9999) -> None: - target_list = [] - source_list = [] - for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): - assert ( - tgt_params.dtype == torch.float32 - ), f"EMA model only works in FP32 dtype, got {tgt_params.dtype} instead." 
- target_list.append(tgt_params) - source_list.append(src_params.data) - torch._foreach_mul_(target_list, beta) - torch._foreach_add_(target_list, source_list, alpha=1.0 - beta) - - def copy_to(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module) -> None: - for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): - tgt_params.data.copy_(src_params.data) - - def cache(self, parameters: Any, is_cpu: bool = False) -> None: - """Save the current parameters for restoring later. - - Args: - parameters (iterable): Iterable of torch.nn.Parameter to be temporarily stored. - """ - assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?" - device = "cpu" if is_cpu else "cuda" - self.collected_params = [param.clone().to(device) for param in parameters] - self.is_cached = True - - def restore(self, parameters: Any) -> None: - """Restore the parameters in self.collected_params. - - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before copy_to(). - After validation (or model saving), use this to restore the former parameters. - - Args: - parameters (iterable): Iterable of torch.nn.Parameter to be updated with the stored parameters. - """ - assert self.is_cached, "EMA cache is not taken yet." - for c_param, param in zip(self.collected_params, parameters, strict=False): - param.data.copy_(c_param.data.type_as(param.data)) - self.collected_params = [] - # Release the cache after we call restore - self.is_cached = False - - -def get_buffer_name(param_name: str, torch_compile_buffer_renaming: bool = False) -> str: - """ - This function creates buffer name used by EMA from parameter's name - - Args: - param_name (str): Model's parameter name - Returns: - buffer_name (str): buffer name to be used for given parameter name - """ - - buffer_name = param_name.replace(".", "-") - - if torch_compile_buffer_renaming: - # torch.compile() adds _orig_mod to state dict names, this way we get original name - buffer_name = buffer_name.replace("_orig_mod-", "") - - return buffer_name - - -class EMAModelTracker(torch.nn.Module): - """This is a class to track the EMA model weights. - - The EMA weights are registered as buffers, which are extractable as state dicts. The names follow those of the - regular weights, except all "." are replaced with "-" (limitation of register_buffer()). This is similar to SDXL's - implementation of EMA. There are no optimizable parameters. - - Attributes: - collected_params (list): temporarily stores the regular weights while in EMA mode. - beta (float): EMA decay rate. (default: 0.9999). - torch_compile_buffer_renaming (bool): whether to remove '_orig_mod-' from buffer names when torch.compile is used - """ - - def __init__(self, model: Model, beta: float = 0.9999, torch_compile_buffer_renaming: bool = False): - """Constructor of the EMA model weight tracker. - - Args: - model (Model): The PyTorch model. - beta (float): EMA decay rate. (default: 0.9999). - """ - super().__init__() - self.torch_compile_buffer_renaming: bool = torch_compile_buffer_renaming - if not 0.0 <= beta <= 1.0: - raise ValueError("Decay must be between 0 and 1") - self.beta = beta - for name, param in model.named_parameters(): - if param.requires_grad: - buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) - self.register_buffer(buffer_name, param.clone().detach().data) - self.collected_params = [] - # Flag to indicate whether the cache is taken or not. 
Useful to avoid cache overwrite - self.is_cached = False - - @torch.no_grad() - def update_average(self, model: Model, iteration: Optional[int] = None) -> None: - del iteration - target_list = [] - source_list = [] - ema_buffers = self.state_dict() - for name, param in model.named_parameters(): - if param.requires_grad: - buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) - buffer = ema_buffers[buffer_name] - assert buffer.dtype == torch.float32, f"EMA model only works in FP32 dtype, got {buffer.dtype} instead." - target_list.append(buffer) - source_list.append(param.data) - torch._foreach_mul_(target_list, self.beta) - torch._foreach_add_(target_list, source_list, alpha=1.0 - self.beta) - - def copy_to(self, model: Model) -> None: - ema_buffers = self.state_dict() - for name, param in model.named_parameters(): - if param.requires_grad: - buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) - buffer = ema_buffers[buffer_name] - param.data.copy_(buffer.data) - - def cache(self, parameters: Any, is_cpu: bool = False) -> None: - """Save the current parameters for restoring later. - - Args: - parameters (iterable): Iterable of torch.nn.Parameter to be temporarily stored. - """ - assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?" - device = "cpu" if is_cpu else "cuda" - self.collected_params = [param.clone().to(device) for param in parameters] - self.is_cached = True - - def restore(self, parameters: Any) -> None: - """Restore the parameters in self.collected_params. - - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before copy_to(). - After validation (or model saving), use this to restore the former parameters. - - Args: - parameters (iterable): Iterable of torch.nn.Parameter to be updated with the stored parameters. - """ - assert self.is_cached, "EMA cache is not taken yet." - for c_param, param in zip(self.collected_params, parameters, strict=False): - param.data.copy_(c_param.data.type_as(param.data)) - self.collected_params = [] - # Release the cache after we call restore - self.is_cached = False - - @classmethod - def initialize_multi_rank_ema( - cls, model: torch.nn.Module, rate: Union[float, List[float]], num: int = 1, enabled: bool = True - ) -> Optional[EMAModelTracker]: - """ - Class method to initialize per rank EMA Model Tracker with different rate. - Each rank will have a different rate based on the given configuration, resulting in different EMA weights. - - Args: - model (torch.nn.Module): The neural network model to be tracked. - rate (Union[float, List[float]]): The decay rate(s) for the EMA. If a list is provided, - it corresponds to rates for different ranks. - num (int, optional): The number of leading ranks to consider for different rates. - Defaults to 1. - enabled (bool, optional): Flag to enable or disable the creation of the tracker. - If False, returns None. Defaults to True. - - Returns: - Optional[EMAModelTracker]: An instance of EMAModelTracker if enabled, otherwise None. - - Example: - >>> model = torch.nn.Linear(10, 2) - >>> tracker = EMAModelTracker.initialize_ema_from_settings(model, rate=[0.1, 0.2], num=2) - >>> print(tracker) - - Notes: - If `rate` is a list and the current rank is less than `num`, the rate for the current rank - is used. If the current rank exceeds `num`, the first rate in the list is used by default. 
- """ - if not enabled: - return None - if parallel_state.is_initialized(): - cur_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) - log.critical(f"using MCore parallel_state for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) - log.warning("It should not used together with FSDP!") - else: - cur_dp_rank = distributed.get_rank() - log.critical(f"using torch.distributed for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) - rate = rate if isinstance(rate, list) else [rate] - num = min(num, len(rate)) - rate = rate[cur_dp_rank] if cur_dp_rank < num else rate[0] - if cur_dp_rank < num: - print(f"EMAModelTracker: rank {cur_dp_rank}, rate {rate}") - return cls(model, rate) - - -class PowerEMATracker(EMAModelTracker): - def __init__(self, model: Model, s: float = 0.1, torch_compile_buffer_renaming: bool = False): - """Constructor of the EMA model weight tracker. - - Args: - model (Model): The PyTorch model. - s (float): EMA decay rate. See EDM2 paper - torch_compile_buffer_renaming (bool): whether to remove '_orig_mod-' from buffer names when torch.compile is used - """ - super().__init__(model=model, beta=0.0, torch_compile_buffer_renaming=torch_compile_buffer_renaming) - self.exp = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max() - - @torch.no_grad() - def update_average(self, model: Model, iteration: Optional[int] = None) -> None: - if iteration == 0: - beta = 0.0 - else: - i = iteration + 1 - beta = (1 - 1 / i) ** (self.exp + 1) - self.beta = beta - - super().update_average(model, iteration) - - @classmethod - def initialize_multi_rank_ema( - cls, model: torch.nn.Module, rate: float, num: int, enabled: bool = True - ) -> Optional[PowerEMATracker]: - """ - Class method to initialize per rank EMA Model Tracker with different rate. - Each rank will have a different rate based on the given configuration, resulting in different EMA weights. - - Args: - model (torch.nn.Module): The neural network model for which the EMA tracker is being set up. - num (int): The number of ranks for which the rate adjustment is applied. Beyond this, the rate remains unchanged. - rate (float): The base decay rate for the EMA calculation. - enabled (bool, optional): Flag to enable or disable the initialization of the tracker. If False, returns None. - Defaults to True. - - Returns: - Optional[PowerEMATracker]: An instance of PowerEMATracker with adjusted rate if enabled, otherwise None. - - Raises: - None - - Example: - >>> model = torch.nn.Linear(10, 2) - >>> tracker = PowerEMATracker.initialize_multi_rank_ema(model, num=3, rate=0.99) - >>> print(tracker) - - Notes: - The decay rate is modified by dividing it by 2 raised to the power of the rank for each rank less than `num`. - If the rank is greater than or equal to `num`, the base rate is used without modification. This approach - allows higher ranked processes to have a less aggressive decay, potentially reflecting their delayed synchronization - in a distributed training scenario. - """ - if not enabled: - return None - if parallel_state.is_initialized(): - cur_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) - log.critical(f"using MCore parallel_state for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) - log.warning("It should not used together with FSDP!") - else: - cur_dp_rank = distributed.get_rank() - log.critical(f"using torch.distributed for EMA initialization. 
DP RANK: {cur_dp_rank}", rank0_only=False) - - divider = 2**cur_dp_rank if cur_dp_rank < num else 1 - if cur_dp_rank < num: - print(f"PowerEMATracker: rank {cur_dp_rank}, rate {rate / divider}") - return cls(model, rate / divider) - - -@contextmanager -def ema_scope(model: Model, enabled: bool = False) -> Generator[None, None, None]: - """Context manager for switching between regular and EMA model weights. - - Args: - model (Model): The PyTorch model. - enabled (bool): Whether switching to EMA weights is enabled (default: False). - """ - if enabled: - assert hasattr(model, "ema") and isinstance(model.ema, (FastEmaModelUpdater, EMAModelTracker, PowerEMATracker)) - model.ema.cache(model.parameters()) - model.ema.copy_to(model) - log.info("EMA: switched to EMA weights.") - try: - yield None - finally: - if enabled: - model.ema.restore(model.parameters()) - log.info("EMA: restored regular weights.") diff --git a/cosmos-transfer1/cosmos_transfer1/utils/fused_adam.py b/cosmos-transfer1/cosmos_transfer1/utils/fused_adam.py deleted file mode 100644 index 76268d932017b2f6837c1b04ee1366d5eebdd178..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/fused_adam.py +++ /dev/null @@ -1,398 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from apex.multi_tensor_apply import multi_tensor_applier - -from cosmos_transfer1.utils import distributed, log - - -class FusedAdam(torch.optim.Optimizer): - """Implements Adam algorithm. - - Currently GPU-only. Requires Apex to be installed via - ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``. - - This version of fused Adam implements 2 fusions. - - * Fusion of the Adam update's elementwise operations - * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters - into one or a few kernel launches. - - :class:`apex.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, - or ``torch.optim.Adam`` with ``adam_w_mode=False``:: - - opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....) - ... - opt.step() - - :class:`apex.optimizers.FusedAdam` may be used with or without Amp. If you wish to use :class:`FusedAdam` with Amp, - you may choose any ``opt_level``:: - - opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....) - model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2") - ... - opt.step() - - In general, ``opt_level="O1"`` is recommended. - - - .. warning:: - A previous version of :class:`FusedAdam` allowed a number of additional arguments to ``step``. - These additional arguments are now deprecated and unnecessary. - - Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups. 
- lr (float, optional): learning rate. (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square. (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability. (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - (default: False) NOT SUPPORTED in FusedAdam! - adam_w_mode (boolean, optional): Apply L2 regularization or weight decay - True for decoupled weight decay(also known as AdamW) (default: True) - capturable (bool, optional): whether to use the version of the optimizer - that can be used with CUDA Graphs. (default: False) - master_weights (bool, optional): whether to maintain FP32 master weights - in the optimizer with FP16 mixed precision training, currently can - only be used with capturable set to True. (default: False) - - .. _Adam - A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__( - self, - params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-8, - adam_w_mode=True, - weight_decay=0.0, - amsgrad=False, - capturable=False, - master_weights=False, - ): - if amsgrad: - raise RuntimeError("FusedAdam does not support the AMSGrad variant.") - if master_weights and not capturable: - raise RuntimeError("Master weights is currently only supported with the capturable version.") - # If the optimizer is capturable then LR should be a tensor (on GPU) - log.warning(f"FusedAdam master_weights: {master_weights} capturable: {capturable}") - lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr - defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay) - super(FusedAdam, self).__init__(params, defaults) - self.adam_w_mode = 1 if adam_w_mode else 0 - - self.capturable = capturable - self.master_weights = master_weights - - self.param_groups_master = None - - if capturable: - for idx, group in enumerate(self.param_groups): - if len(group["params"]) == 0: - continue - device = group["params"][0].device - for item in ["lr"]: - if isinstance(group[item], float): - group[item] = torch.tensor(group[item], dtype=torch.float32) - self.param_groups[idx][item] = group[item].to(device=device) - - self._step_supports_amp_scaling = True - - if multi_tensor_applier.available: - import amp_C - - # Skip buffer - self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") - self.multi_tensor_adam = amp_C.multi_tensor_adam - self.multi_tensor_adam_capturable = amp_C.multi_tensor_adam_capturable - self.multi_tensor_adam_capturable_master = amp_C.multi_tensor_adam_capturable_master - else: - raise RuntimeError("apex.optimizers.FusedAdam requires cuda extensions") - - def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None, grad_scaler=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - - The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes. - """ - if any(p is not None for p in [grads, output_params, scale, grad_norms]): - raise RuntimeError( - "FusedAdam has been updated. 
" - "Simply initialize it identically to torch.optim.Adam, and call step() with no arguments." - ) - loss = None - if closure is not None: - loss = closure() - - if self.param_groups_master is None: - # Create full precision master weights - self.param_groups_master = [] - for i, pg in enumerate(self.param_groups): - param_list = pg["params"] - self.param_groups_master.append( - { - "params": [p.clone().detach().float() if self.master_weights else None for p in param_list], - } - ) - - for group, group_master in zip(self.param_groups, self.param_groups_master): - if len(group["params"]) == 0: - continue - device = group["params"][0].device - bias_correction = 1 if "bias_correction" in group and group["bias_correction"] else 0 - beta1, beta2 = group["betas"] - - # assume same step across group now to simplify things - # per parameter step can be easily support by making it tensor, or pass list into kernel - if "step" in group: - if self.capturable: - group["step"] = ( - group["step"].to(device=device) - if isinstance(group["step"], torch.Tensor) - else torch.tensor(group["step"], dtype=torch.int32, device=device) - ) - group["step"] += (self._dummy_overflow_buf != 1).to(torch.int) - else: - group["step"] += 1 - else: - group["step"] = 1 if not self.capturable else torch.tensor([1], dtype=torch.int, device=device) - - if self.capturable: - group["lr"] = ( - group["lr"].to(device=device) - if isinstance(group["lr"], torch.Tensor) - else torch.tensor(group["lr"], dtype=torch.float32, device=device) - ) - - # create lists for multi-tensor apply - g_16, p_16, m_16, v_16 = [], [], [], [] - g_bf, p_bf, m_bf, v_bf = [], [], [], [] - g_32, p_32, m_32, v_32 = [], [], [], [] - p_16_master = [] - p_32_master = [] - bf16_master = [] - - for p, p_master in zip(group["params"], group_master["params"]): - if p.grad is None: - continue - if p.grad.data.is_sparse: - raise RuntimeError( - "FusedAdam does not support sparse gradients, please consider SparseAdam instead" - ) - - state = self.state[p] - # State initialization - if len(state) == 0: - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like(p.data).float() - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like(p.data).float() - - if p.dtype == torch.float16: - if self.master_weights: - p_16_master.append(p_master.data) - g_16.append(p.grad.data) - p_16.append(p.data) - m_16.append(state["exp_avg"]) - v_16.append(state["exp_avg_sq"]) - elif p.dtype == torch.bfloat16: - if self.master_weights: - bf16_master.append(p_master.data) - g_bf.append(p.grad) - p_bf.append(p) - m_bf.append(state["exp_avg"]) - v_bf.append(state["exp_avg_sq"]) - elif p.dtype == torch.float32: - if self.master_weights: - p_32_master.append(p_master.data) - g_32.append(p.grad.data) - p_32.append(p.data) - m_32.append(state["exp_avg"]) - v_32.append(state["exp_avg_sq"]) - else: - raise RuntimeError("FusedAdam only support fp16 and fp32.") - - # If the optimizer is capturable, then if there's a grad scaler it works - # on the GPU + a different multi_tensor_applier should be called - if self.capturable: - # overflow check of gradients - found_inf = ( - grad_scaler._check_inf_per_device(self)[device] - if grad_scaler is not None - else torch.zeros((1,), device=device) - ) - self._dummy_overflow_buf.copy_(found_inf) - - # get unscale scale factor - scale, inv_scale = None, None - if grad_scaler: - scale = grad_scaler._get_scale_async() - inv_scale = scale.double().reciprocal().float() - else: - scale = 
torch.ones((1,), device=device, dtype=torch.float32) - inv_scale = torch.ones((1,), device=device, dtype=torch.float32) - - if len(g_16) > 0: - multi_tensor_applier( - ( - self.multi_tensor_adam_capturable_master - if self.master_weights - else self.multi_tensor_adam_capturable - ), - self._dummy_overflow_buf, - [g_16, p_16, m_16, v_16, p_16_master] if self.master_weights else [g_16, p_16, m_16, v_16], - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - inv_scale, - ) - - if len(g_bf) > 0: - multi_tensor_applier( - ( - self.multi_tensor_adam_capturable_master - if self.master_weights - else self.multi_tensor_adam_capturable - ), - self._dummy_overflow_buf, - [g_bf, p_bf, m_bf, v_bf, bf16_master] if self.master_weights else [g_bf, p_bf, m_bf, v_bf], - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - inv_scale, - ) - - if len(g_32) > 0: - multi_tensor_applier( - ( - self.multi_tensor_adam_capturable_master - if self.master_weights - else self.multi_tensor_adam_capturable - ), - self._dummy_overflow_buf, - [g_32, p_32, m_32, v_32, p_32_master] if self.master_weights else [g_32, p_32, m_32, v_32], - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - inv_scale, - ) - else: - if len(g_16) > 0: - multi_tensor_applier( - self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_16, p_16, m_16, v_16], - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - ) - - if len(g_bf) > 0: - multi_tensor_applier( - self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_bf, p_bf, m_bf, v_bf], - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - ) - - if len(g_32) > 0: - multi_tensor_applier( - self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_32, p_32, m_32, v_32], - group["lr"], - beta1, - beta2, - group["eps"], - group["step"], - self.adam_w_mode, - bias_correction, - group["weight_decay"], - ) - - return loss - - def load_state_dict(self, state_dict): - super().load_state_dict(state_dict) - for group in self.param_groups: - if self.capturable: - group["lr"] = ( - group["lr"].cuda() - if isinstance(group["lr"], torch.Tensor) - else torch.tensor(group["lr"], dtype=torch.float32).cuda() - ) - - if "step" in group: - if self.capturable: - if distributed.get_rank() == 0: - step = ( - group["step"].cuda() - if isinstance(group["step"], torch.Tensor) - else torch.tensor([group["step"]], dtype=torch.int32).cuda() - ) - else: - step = torch.zeros(1, dtype=torch.int32).cuda() - # make it compatible with FSDP optimizer - distributed.broadcast(step, 0) - group["step"] = step - elif isinstance(group["step"], torch.Tensor): - group["step"] = group["step"].item() - for p in group["params"]: - state = self.state[p] - if "exp_avg" in state: - state["exp_avg"] = state["exp_avg"].float() - state["exp_avg_sq"] = state["exp_avg_sq"].float() diff --git a/cosmos-transfer1/cosmos_transfer1/utils/io.py b/cosmos-transfer1/cosmos_transfer1/utils/io.py deleted file mode 100644 index c877aa41fd6b90638281f048bac23fc8214b84be..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/io.py +++ /dev/null @@ -1,89 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from io import BytesIO -from typing import Dict, List - -import imageio -import numpy as np - - -def read_prompts_from_file(prompt_file: str) -> List[Dict[str, str]]: - """Read prompts from a JSONL file where each line is a dict with 'prompt' key and optionally 'visual_input' key. - - Args: - prompt_file (str): Path to JSONL file containing prompts - - Returns: - List[Dict[str, str]]: List of prompt dictionaries - """ - prompts = [] - with open(prompt_file, "r") as f: - for line in f: - prompt_dict = json.loads(line.strip()) - prompts.append(prompt_dict) - return prompts - - -def save_video(video, fps, H, W, video_save_quality, video_save_path): - """Save video frames to file. - - Args: - grid (np.ndarray): Video frames array [T,H,W,C] - fps (int): Frames per second - H (int): Frame height - W (int): Frame width - video_save_quality (int): Video encoding quality (0-10) - video_save_path (str): Output video file path - """ - kwargs = { - "fps": fps, - "quality": video_save_quality, - "macro_block_size": 1, - "ffmpeg_params": ["-s", f"{W}x{H}"], - "output_params": ["-f", "mp4"], - } - imageio.mimsave(video_save_path, video, "mp4", **kwargs) - - -def load_from_fileobj(filepath: str, format: str = "mp4", mode: str = "rgb", **kwargs): - """ - Load video from a file-like object using imageio with specified format and color mode. - - Parameters: - file (IO[bytes]): A file-like object containing video data. - format (str): Format of the video file (default 'mp4'). - mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb'). - - Returns: - tuple: A tuple containing an array of video frames and metadata about the video. - """ - with open(filepath, "rb") as f: - value = f.read() - with BytesIO(value) as f: - f.seek(0) - video_reader = imageio.get_reader(f, format, **kwargs) - - video_frames = [] - for frame in video_reader: - if mode == "gray": - import cv2 # Convert frame to grayscale if mode is gray - - frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) - frame = np.expand_dims(frame, axis=2) # Keep frame dimensions consistent - video_frames.append(frame) - - return np.array(video_frames), video_reader.get_meta_data() diff --git a/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/__init__.py b/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/__init__.py deleted file mode 100644 index dbe7c341ce9d5b7c3a53bbbb2dd268118a210dc2..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/__init__.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. 
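(Editor's note: the io.py removed above reads prompts from a JSONL file, one JSON object per line with a "prompt" key and an optional "visual_input" key. The sketch below only illustrates that file format and an equivalent read; the sample prompts and paths are made up, and it uses the standard library rather than the deleted function.)

# Sketch of the JSONL prompt format consumed by the deleted read_prompts_from_file.
import json
import tempfile

entries = [
    {"prompt": "a rainy street at night", "visual_input": "inputs/street.mp4"},  # hypothetical sample
    {"prompt": "a sunny highway"},
]

with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    for entry in entries:
        f.write(json.dumps(entry) + "\n")
    prompt_file = f.name

# Equivalent of read_prompts_from_file(prompt_file): one dict per non-empty line.
with open(prompt_file) as f:
    prompts = [json.loads(line.strip()) for line in f if line.strip()]
print(prompts[0]["prompt"])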
-import os - -from omegaconf import DictConfig, OmegaConf - -from cosmos_transfer1.utils.lazy_config.instantiate import instantiate -from cosmos_transfer1.utils.lazy_config.lazy import LazyCall, LazyConfig -from cosmos_transfer1.utils.lazy_config.omegaconf_patch import to_object - -OmegaConf.to_object = to_object - -PLACEHOLDER = None -LazyDict = DictConfig - -__all__ = ["instantiate", "LazyCall", "LazyConfig", "PLACEHOLDER", "LazyDict"] - - -DOC_BUILDING = os.getenv("_DOC_BUILDING", False) # set in docs/conf.py - - -def fixup_module_metadata(module_name, namespace, keys=None): - """ - Fix the __qualname__ of module members to be their exported api name, so - when they are referenced in docs, sphinx can find them. Reference: - https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241 - """ - if not DOC_BUILDING: - return - seen_ids = set() - - def fix_one(qualname, name, obj): - # avoid infinite recursion (relevant when using - # typing.Generic, for example) - if id(obj) in seen_ids: - return - seen_ids.add(id(obj)) - - mod = getattr(obj, "__module__", None) - if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")): - obj.__module__ = module_name - # Modules, unlike everything else in Python, put fully-qualitied - # names into their __name__ attribute. We check for "." to avoid - # rewriting these. - if hasattr(obj, "__name__") and "." not in obj.__name__: - obj.__name__ = name - obj.__qualname__ = qualname - if isinstance(obj, type): - for attr_name, attr_value in obj.__dict__.items(): - fix_one(objname + "." + attr_name, attr_name, attr_value) - - if keys is None: - keys = namespace.keys() - for objname in keys: - if not objname.startswith("_"): - obj = namespace[objname] - fix_one(objname, objname, obj) - - -fixup_module_metadata(__name__, globals(), __all__) -del fixup_module_metadata diff --git a/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/file_io.py b/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/file_io.py deleted file mode 100644 index d9caf0081976dd08ab6ea1c04ad53304bc51d05d..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/file_io.py +++ /dev/null @@ -1,24 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler -from iopath.common.file_io import PathManager as PathManagerBase - -__all__ = ["PathManager", "PathHandler"] - - -PathManager = PathManagerBase() -PathManager.register_handler(HTTPURLHandler()) -PathManager.register_handler(OneDrivePathHandler()) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/instantiate.py b/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/instantiate.py deleted file mode 100644 index 206cf6aef62c6e60a81d755c6b88f9ea7d1653c9..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/instantiate.py +++ /dev/null @@ -1,113 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections.abc as abc -import dataclasses -import logging -from typing import Any - -import attrs - -from cosmos_transfer1.utils.lazy_config.registry import _convert_target_to_string, locate - -__all__ = ["dump_dataclass", "instantiate"] - - -def is_dataclass_or_attrs(target): - return dataclasses.is_dataclass(target) or attrs.has(target) - - -def dump_dataclass(obj: Any): - """ - Dump a dataclass recursively into a dict that can be later instantiated. - - Args: - obj: a dataclass object - - Returns: - dict - """ - assert dataclasses.is_dataclass(obj) and not isinstance( - obj, type - ), "dump_dataclass() requires an instance of a dataclass." - ret = {"_target_": _convert_target_to_string(type(obj))} - for f in dataclasses.fields(obj): - v = getattr(obj, f.name) - if dataclasses.is_dataclass(v): - v = dump_dataclass(v) - if isinstance(v, (list, tuple)): - v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v] - ret[f.name] = v - return ret - - -def instantiate(cfg, *args, **kwargs): - """ - Recursively instantiate objects defined in dictionaries by - "_target_" and arguments. - - Args: - cfg: a dict-like object with "_target_" that defines the caller, and - other keys that define the arguments - args: Optional positional parameters pass-through. - kwargs: Optional named parameters pass-through. - - Returns: - object instantiated by cfg - """ - from omegaconf import DictConfig, ListConfig, OmegaConf - - if isinstance(cfg, ListConfig): - lst = [instantiate(x) for x in cfg] - return ListConfig(lst, flags={"allow_objects": True}) - if isinstance(cfg, list): - # Specialize for list, because many classes take - # list[objects] as arguments, such as ResNet, DatasetMapper - return [instantiate(x) for x in cfg] - - # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config), - # instantiate it to the actual dataclass. 
- if isinstance(cfg, DictConfig) and is_dataclass_or_attrs(cfg._metadata.object_type): - return OmegaConf.to_object(cfg) - - if isinstance(cfg, abc.Mapping) and "_target_" in cfg: - # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all, - # but faster: https://github.com/facebookresearch/hydra/issues/1200 - cfg = {k: instantiate(v) for k, v in cfg.items()} - cls = cfg.pop("_target_") - cls = instantiate(cls) - - if isinstance(cls, str): - cls_name = cls - cls = locate(cls_name) - assert cls is not None, cls_name - else: - try: - cls_name = cls.__module__ + "." + cls.__qualname__ - except Exception: - # target could be anything, so the above could fail - cls_name = str(cls) - assert callable(cls), f"_target_ {cls} does not define a callable object" - try: - # override config with kwargs - instantiate_kwargs = {} - instantiate_kwargs.update(cfg) - instantiate_kwargs.update(kwargs) - return cls(*args, **instantiate_kwargs) - except TypeError: - logger = logging.getLogger(__name__) - logger.error(f"Error when instantiating {cls_name}!") - raise - return cfg # return as-is if don't know what to do diff --git a/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/lazy.py b/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/lazy.py deleted file mode 100644 index 6db66f42d7f2f2870811c86c0d2f8b748d3b1c69..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/lazy.py +++ /dev/null @@ -1,430 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
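(Editor's note: the instantiate() removed above resolves any mapping carrying a "_target_" key into an object: nested values are instantiated first, the dotted path is located, and the remaining keys become keyword arguments. The sketch below is a deliberately tiny standalone version of that idea, not the deleted implementation; `tiny_instantiate` is a hypothetical name and the OmegaConf/DictConfig handling is omitted.)

# Minimal sketch of "_target_"-style instantiation: recurse into nested
# containers, resolve the dotted path, then call it with the remaining keys.
import pydoc


def tiny_instantiate(cfg):
    if isinstance(cfg, dict) and "_target_" in cfg:
        kwargs = {k: tiny_instantiate(v) for k, v in cfg.items() if k != "_target_"}
        target = pydoc.locate(cfg["_target_"])
        assert callable(target), cfg["_target_"]
        return target(**kwargs)
    if isinstance(cfg, list):
        return [tiny_instantiate(x) for x in cfg]
    return cfg


obj = tiny_instantiate({"_target_": "collections.OrderedDict", "a": 1, "b": 2})
print(type(obj), dict(obj))  # <class 'collections.OrderedDict'> {'a': 1, 'b': 2}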
- -import ast -import builtins -import collections.abc as abc -import importlib -import inspect -import logging -import os -import pickle -import uuid -from collections import OrderedDict -from contextlib import contextmanager -from copy import deepcopy -from dataclasses import is_dataclass -from typing import Any, Dict, List, Tuple, Union - -import attrs -import yaml -from omegaconf import DictConfig, ListConfig, OmegaConf - -from cosmos_transfer1.utils.lazy_config.file_io import PathManager -from cosmos_transfer1.utils.lazy_config.registry import _convert_target_to_string - -try: - import dill as dill_pickle -except ImportError: - dill_pickle = None -try: - import cloudpickle -except ImportError: - cloudpickle = None - -__all__ = ["LazyCall", "LazyConfig"] - - -def sort_dict(d: Dict[str, Any]) -> OrderedDict[str, Any]: - return OrderedDict(sorted(d.items(), key=lambda x: x[0])) - - -def dict_representer(dumper: yaml.Dumper, data: OrderedDict[str, Any]) -> yaml.nodes.MappingNode: - return dumper.represent_mapping("tag:yaml.org,2002:map", data.items()) - - -def sort_recursive(obj: Union[Dict[str, Any], List[Any], Any]) -> Union[OrderedDict[str, Any], List[Any], Any]: - if isinstance(obj, dict): - return sort_dict({k: sort_recursive(v) for k, v in obj.items()}) - elif isinstance(obj, list): - return [sort_recursive(item) for item in obj] - return obj - - -yaml.add_representer(OrderedDict, dict_representer) - - -def get_default_params(cls_or_func): - if callable(cls_or_func): - # inspect signature for function - signature = inspect.signature(cls_or_func) - else: - # inspect signature for class - signature = inspect.signature(cls_or_func.__init__) - params = signature.parameters - default_params = { - name: param.default for name, param in params.items() if param.default is not inspect.Parameter.empty - } - return default_params - - -class LazyCall: - """ - Wrap a callable so that when it's called, the call will not be executed, - but returns a dict that describes the call. - - LazyCall object has to be called with only keyword arguments. Positional - arguments are not yet supported. - - Examples: - :: - from detectron2.config import instantiate, LazyCall - - layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32) - layer_cfg.out_channels = 64 # can edit it afterwards - layer = instantiate(layer_cfg) - """ - - def __init__(self, target): - if not (callable(target) or isinstance(target, (str, abc.Mapping))): - raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}") - self._target = target - - def __call__(self, **kwargs): - if is_dataclass(self._target) or attrs.has(self._target): - # omegaconf object cannot hold dataclass type - # https://github.com/omry/omegaconf/issues/784 - target = _convert_target_to_string(self._target) - else: - target = self._target - kwargs["_target_"] = target - - _final_params = get_default_params(self._target) - _final_params.update(kwargs) - - return DictConfig(content=_final_params, flags={"allow_objects": True}) - - -def _visit_dict_config(cfg, func): - """ - Apply func recursively to all DictConfig in cfg. 
- """ - if isinstance(cfg, DictConfig): - func(cfg) - for v in cfg.values(): - _visit_dict_config(v, func) - elif isinstance(cfg, ListConfig): - for v in cfg: - _visit_dict_config(v, func) - - -def _validate_py_syntax(filename): - # see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py - with PathManager.open(filename, "r") as f: - content = f.read() - try: - ast.parse(content) - except SyntaxError as e: - raise SyntaxError(f"Config file {filename} has syntax error!") from e - - -def _cast_to_config(obj): - # if given a dict, return DictConfig instead - if isinstance(obj, dict): - return DictConfig(obj, flags={"allow_objects": True}) - return obj - - -_CFG_PACKAGE_NAME = "detectron2._cfg_loader" -""" -A namespace to put all imported config into. -""" - - -def _random_package_name(filename): - # generate a random package name when loading config files - return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename) - - -@contextmanager -def _patch_import(): - """ - Enhance relative import statements in config files, so that they: - 1. locate files purely based on relative location, regardless of packages. - e.g. you can import file without having __init__ - 2. do not cache modules globally; modifications of module states has no side effect - 3. support other storage system through PathManager, so config files can be in the cloud - 4. imported dict are turned into omegaconf.DictConfig automatically - """ - old_import = builtins.__import__ - - def find_relative_file(original_file, relative_import_path, level): - # NOTE: "from . import x" is not handled. Because then it's unclear - # if such import should produce `x` as a python module or DictConfig. - # This can be discussed further if needed. - relative_import_err = """ -Relative import of directories is not allowed within config files. -Within a config file, relative import can only import other config files. -""".replace( - "\n", " " - ) - if not len(relative_import_path): - raise ImportError(relative_import_err) - - cur_file = os.path.dirname(original_file) - for _ in range(level - 1): - cur_file = os.path.dirname(cur_file) - cur_name = relative_import_path.lstrip(".") - for part in cur_name.split("."): - cur_file = os.path.join(cur_file, part) - if not cur_file.endswith(".py"): - cur_file += ".py" - if not PathManager.isfile(cur_file): - cur_file_no_suffix = cur_file[: -len(".py")] - if PathManager.isdir(cur_file_no_suffix): - raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err) - else: - raise ImportError( - f"Cannot import name {relative_import_path} from " f"{original_file}: {cur_file} does not exist." 
- ) - return cur_file - - def new_import(name, globals=None, locals=None, fromlist=(), level=0): - if ( - # Only deal with relative imports inside config files - level != 0 - and globals is not None - and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME) - ): - cur_file = find_relative_file(globals["__file__"], name, level) - _validate_py_syntax(cur_file) - spec = importlib.machinery.ModuleSpec(_random_package_name(cur_file), None, origin=cur_file) - module = importlib.util.module_from_spec(spec) - module.__file__ = cur_file - with PathManager.open(cur_file) as f: - content = f.read() - exec(compile(content, cur_file, "exec"), module.__dict__) - for name in fromlist: # turn imported dict into DictConfig automatically - val = _cast_to_config(module.__dict__[name]) - module.__dict__[name] = val - return module - return old_import(name, globals, locals, fromlist=fromlist, level=level) - - builtins.__import__ = new_import - yield new_import - builtins.__import__ = old_import - - -class LazyConfig: - """ - Provide methods to save, load, and overrides an omegaconf config object - which may contain definition of lazily-constructed objects. - """ - - @staticmethod - def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): - """ - Similar to :meth:`load()`, but load path relative to the caller's - source file. - - This has the same functionality as a relative import, except that this method - accepts filename as a string, so more characters are allowed in the filename. - """ - caller_frame = inspect.stack()[1] - caller_fname = caller_frame[0].f_code.co_filename - assert caller_fname != "", "load_rel Unable to find caller" - caller_dir = os.path.dirname(caller_fname) - filename = os.path.join(caller_dir, filename) - return LazyConfig.load(filename, keys) - - @staticmethod - def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): - """ - Load a config file. - - Args: - filename: absolute path or relative path w.r.t. the current working directory - keys: keys to load and return. If not given, return all keys - (whose values are config objects) in a dict. - """ - has_keys = keys is not None - filename = filename.replace("/./", "/") # redundant - if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]: - raise ValueError(f"Config file {filename} has to be a python or yaml file.") - if filename.endswith(".py"): - _validate_py_syntax(filename) - - with _patch_import(): - # Record the filename - module_namespace = { - "__file__": filename, - "__package__": _random_package_name(filename), - } - with PathManager.open(filename) as f: - content = f.read() - # Compile first with filename to: - # 1. make filename appears in stacktrace - # 2. 
make load_rel able to find its parent's (possibly remote) location - exec(compile(content, filename, "exec"), module_namespace) - - ret = module_namespace - else: - with PathManager.open(filename) as f: - obj = yaml.unsafe_load(f) - ret = OmegaConf.create(obj, flags={"allow_objects": True}) - - if has_keys: - if isinstance(keys, str): - return _cast_to_config(ret[keys]) - else: - return tuple(_cast_to_config(ret[a]) for a in keys) - else: - if filename.endswith(".py"): - # when not specified, only load those that are config objects - ret = DictConfig( - { - name: _cast_to_config(value) - for name, value in ret.items() - if isinstance(value, (DictConfig, ListConfig, dict)) and not name.startswith("_") - }, - flags={"allow_objects": True}, - ) - return ret - - @staticmethod - def save_pkl(cfg, filename: str) -> str: - """ - Saves a Config object to a file using pickle serialization. This method is typically used - when the configuration object contains complex objects, such as lambdas, that are not supported by - simpler serialization methods like YAML. The function attempts to create a deep copy of the configuration - object before serialization to ensure that the original object remains unmodified. - - Args: - cfg: A Config object to be serialized and saved. - filename: The path and name of the file where the configuration should be saved. The function - assumes the file extension indicates a pickle format (e.g., .pkl). - - Returns: - str: The filename to which the configuration was saved. This can be used to verify the file location - or log the outcome. - - Notes: - - The function logs a warning if the configuration is successfully saved using pickle. - - If saving fails, an error is logged with the exception details. - """ - logger = logging.getLogger(__name__) - try: - cfg = deepcopy(cfg) - except Exception: - pass - - try: - with PathManager.open(filename, "wb") as f: - pickle.dump(cfg, f) - logger.warning(f"Config is saved using pickle at {filename}.") - except Exception as e: - logger.error(f"Failed to save config to {filename}: {e}. Trying dill or cloudpickle instead") - if dill_pickle: - try: - with PathManager.open(filename, "wb") as f: - pickle.dump(dill_pickle.dumps(cfg, recurse=True), f) - logger.warning(f"Config is saved using dill at {filename}.") - except Exception as e: - logger.error(f"Failed to save config to {filename}: {e}.") - if cloudpickle: - try: - with PathManager.open(filename, "wb") as f: - pickle.dump(cloudpickle.dumps(cfg), f) - logger.warning(f"Config is saved using cloudpickle at {filename}.") - except Exception as e: - logger.error(f"Failed to save config to {filename}: {e}.") - else: - logger.error("cloudpickle is not available. Cannot save the config.") - raise e - - return filename - - @staticmethod - def save_yaml(cfg, filename: str) -> str: - """ - Saves a Config object to a file using YAML serialization. This method is beneficial when the configuration object's content needs to be human-readable and easily editable. YAML is suitable for configurations that do not contain complex types like lambdas, which must be handled differently. The function converts unserializable items to strings before saving to ensure compatibility with YAML serialization. - - Args: - cfg: A Config object to be serialized and saved. It handles both DictConfig and ListConfig types. - filename: The path and name of the file where the configuration should be saved. The function does not require a specific file extension but typically uses '.yaml'. 
- - Returns: - str: The filename to which the configuration was saved. This can be used to verify the file location or log the outcome. - - Notes: - - The function logs a warning if the configuration is successfully saved using YAML. - - If saving fails, an error is logged with the exception details. - """ - logger = logging.getLogger(__name__) - try: - cfg = deepcopy(cfg) - except Exception: - pass - - # Define a function to check if an item is serializable to YAML - def is_serializable(item): - try: - OmegaConf.to_yaml(item) - return True - except Exception as e: - return False - - # Function to convert unserializable items to strings - def serialize_config(config): - if isinstance(config, DictConfig): - for key, value in config.items(): - if isinstance(value, (DictConfig, ListConfig)): - try: - if "_target_" in value: - default_params = get_default_params(value["_target_"]) - for default_key, default_v in default_params.items(): - if default_key not in value: - value[default_key] = default_v - except Exception as e: - logger.error(f"Failed to add default argument values: {e}") - - serialize_config(value) - else: - if not is_serializable(value) and value is not None: - config[key] = str(value) - elif isinstance(config, ListConfig): - for i, item in enumerate(config): - if isinstance(item, (DictConfig, ListConfig)): - serialize_config(item) - else: - if not is_serializable(item) and item is not None: - config[i] = str(item) - else: - raise NotImplementedError("Input config must be a DictConfig or ListConfig.") - return config - - # Convert Config object to a DictConfig object. - config_dict = attrs.asdict(cfg) - config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True}) - - # Serialize the DictConfig object by converting non-serializable objects to strings. - config_omegaconf = serialize_config(config_omegaconf) - - config_dict: Dict[str, Any] = OmegaConf.to_container(config_omegaconf, resolve=True) - sorted_config: OrderedDict[str, Any] = sort_recursive(config_dict) - with open(filename, "w") as f: - yaml.dump(sorted_config, f, default_flow_style=False) - logger.warning(f"Config is saved using omegaconf at {filename}.") - return filename diff --git a/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/omegaconf_patch.py b/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/omegaconf_patch.py deleted file mode 100644 index 39dca42a0a71383de919b750cedf2606faae206d..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/omegaconf_patch.py +++ /dev/null @@ -1,65 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
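(Editor's note: save_yaml above ends by recursively sorting the plain container and dumping it with PyYAML via an OrderedDict representer. The sketch below illustrates only that sorted-dump step; the helper bodies mirror the deleted sort_recursive/dict_representer, the config values are made up, and PyYAML is assumed to be installed.)

# Standalone illustration of the sorted YAML dump used at the end of the deleted save_yaml.
from collections import OrderedDict

import yaml


def dict_representer(dumper, data):
    return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())


yaml.add_representer(OrderedDict, dict_representer)


def sort_recursive(obj):
    if isinstance(obj, dict):
        return OrderedDict(sorted((k, sort_recursive(v)) for k, v in obj.items()))
    if isinstance(obj, list):
        return [sort_recursive(item) for item in obj]
    return obj


config = {"model": {"num_layers": 12, "hidden": 768}, "lr": 1e-4}  # hypothetical values
print(yaml.dump(sort_recursive(config), default_flow_style=False))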
- -from typing import Any, Dict, List, Union - -from omegaconf import OmegaConf -from omegaconf.base import DictKeyType, SCMode -from omegaconf.dictconfig import DictConfig # pragma: no cover - - -def to_object(cfg: Any) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: - """ - Converts an OmegaConf configuration object to a native Python container (dict or list), unless - the configuration is specifically created by LazyCall, in which case the original configuration - is returned directly. - - This function serves as a modification of the original `to_object` method from OmegaConf, - preventing DictConfig objects created by LazyCall from being automatically converted to Python - dictionaries. This ensures that configurations meant to be lazily evaluated retain their intended - structure and behavior. - - Differences from OmegaConf's original `to_object`: - - Adds a check at the beginning to return the configuration unchanged if it is created by LazyCall. - - Reference: - - Original OmegaConf `to_object` method: https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595 - - Args: - cfg (Any): The OmegaConf configuration object to convert. - - Returns: - Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: The converted Python container if - `cfg` is not a LazyCall created configuration, otherwise the unchanged `cfg`. - - Examples: - >>> cfg = DictConfig({"key": "value", "_target_": "Model"}) - >>> to_object(cfg) - DictConfig({"key": "value", "_target_": "Model"}) - - >>> cfg = DictConfig({"list": [1, 2, 3]}) - >>> to_object(cfg) - {'list': [1, 2, 3]} - """ - if isinstance(cfg, DictConfig) and "_target_" in cfg.keys(): - return cfg - - return OmegaConf.to_container( - cfg=cfg, - resolve=True, - throw_on_missing=True, - enum_to_str=False, - structured_config_mode=SCMode.INSTANTIATE, - ) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/registry.py b/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/registry.py deleted file mode 100644 index 7c09eb428a97927d5f0407e2328a3f43afbf38fc..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/lazy_config/registry.py +++ /dev/null @@ -1,72 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pydoc -from typing import Any - -""" -`locate` provide ways to map a string (typically found -in config files) to callable objects. -""" - -__all__ = ["locate"] - - -def _convert_target_to_string(t: Any) -> str: - """ - Inverse of ``locate()``. - - Args: - t: any object with ``__module__`` and ``__qualname__`` - """ - module, qualname = t.__module__, t.__qualname__ - - # Compress the path to this object, e.g. ``module.submodule._impl.class`` - # may become ``module.submodule.class``, if the later also resolves to the same - # object. This simplifies the string, and also is less affected by moving the - # class implementation. 
- module_parts = module.split(".") - for k in range(1, len(module_parts)): - prefix = ".".join(module_parts[:k]) - candidate = f"{prefix}.{qualname}" - try: - if locate(candidate) is t: - return candidate - except ImportError: - pass - return f"{module}.{qualname}" - - -def locate(name: str) -> Any: - """ - Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``, - such as "module.submodule.class_name". - - Raise Exception if it cannot be found. - """ - obj = pydoc.locate(name) - - # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly - # by pydoc.locate. Try a private function from hydra. - if obj is None: - try: - # from hydra.utils import get_method - will print many errors - from hydra.utils import _locate - except ImportError as e: - raise ImportError(f"Cannot dynamically locate object {name}!") from e - else: - obj = _locate(name) # it raises if fails - - return obj diff --git a/cosmos-transfer1/cosmos_transfer1/utils/log.py b/cosmos-transfer1/cosmos_transfer1/utils/log.py deleted file mode 100644 index 2c4975a0f2e72287bbcdd65e45f8e72556660e1f..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/log.py +++ /dev/null @@ -1,152 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
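# Editor's note (hedged illustration of the registry helpers above; this
# comment block is not part of the original file): locate() maps a dotted
# string to the object it names, and _convert_target_to_string() is its
# inverse, optionally compressing the module path. A typical round trip:
#
#     from cosmos_transfer1.utils.lazy_config.registry import (
#         _convert_target_to_string,
#         locate,
#     )
#     import torch
#
#     cls = locate("torch.optim.AdamW")         # string -> object
#     name = _convert_target_to_string(cls)     # object -> (possibly shorter) string
#     assert locate(name) is cls                # the round trip is stable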
- -import atexit -import os -import sys -from typing import Any, Optional - -import torch.distributed as dist -from loguru._logger import Core, Logger - -RANK0_ONLY = True -LEVEL = os.environ.get("LOGURU_LEVEL", "INFO") - -logger = Logger( - core=Core(), - exception=None, - depth=1, - record=False, - lazy=False, - colors=False, - raw=False, - capture=True, - patchers=[], - extra={}, -) - -atexit.register(logger.remove) - - -def _add_relative_path(record: dict[str, Any]) -> None: - start = os.getcwd() - record["extra"]["relative_path"] = os.path.relpath(record["file"].path, start) - - -*options, _, extra = logger._options # type: ignore -logger._options = tuple([*options, [_add_relative_path], extra]) # type: ignore - - -def init_loguru_stdout() -> None: - logger.remove() - machine_format = get_machine_format() - message_format = get_message_format() - logger.add( - sys.stdout, - level=LEVEL, - format="[{time:MM-DD HH:mm:ss}|" f"{machine_format}" f"{message_format}", - filter=_rank0_only_filter, - ) - - -def get_machine_format() -> str: - node_id = os.environ.get("NGC_ARRAY_INDEX", "0") - num_nodes = int(os.environ.get("NGC_ARRAY_SIZE", "1")) - machine_format = "" - rank = 0 - if dist.is_available(): - if not RANK0_ONLY and dist.is_initialized(): - rank = dist.get_rank() - world_size = dist.get_world_size() - machine_format = ( - f"[Node{node_id:<3}/{num_nodes:<3}][RANK{rank:<5}/{world_size:<5}]" + "[{process.name:<8}]| " - ) - return machine_format - - -def init_loguru_file(path: str) -> None: - machine_format = get_machine_format() - message_format = get_message_format() - logger.add( - path, - encoding="utf8", - level=LEVEL, - format="[{time:MM-DD HH:mm:ss}|" f"{machine_format}" f"{message_format}", - rotation="100 MB", - filter=lambda result: _rank0_only_filter(result) or not RANK0_ONLY, - enqueue=True, - ) - - -def get_message_format() -> str: - message_format = "{level}|{extra[relative_path]}:{line}:{function}] {message}" - return message_format - - -def _rank0_only_filter(record: Any) -> bool: - is_rank0 = record["extra"].get("rank0_only", True) - if _get_rank() == 0 and is_rank0: - return True - if not is_rank0: - record["message"] = f"[RANK {_get_rank()}]" + record["message"] - return not is_rank0 - - -def trace(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).trace(message) - - -def debug(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).debug(message) - - -def info(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).info(message) - - -def success(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).success(message) - - -def warning(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).warning(message) - - -def error(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).error(message) - - -def critical(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).critical(message) - - -def exception(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).exception(message) - - -def _get_rank(group: Optional[dist.ProcessGroup] = None) -> int: - """Get the rank (GPU device) of the worker. - - Returns: - rank (int): The rank of the worker. 
- """ - rank = 0 - if dist.is_available() and dist.is_initialized(): - rank = dist.get_rank(group) - return rank - - -# Execute at import time. -init_loguru_stdout() diff --git a/cosmos-transfer1/cosmos_transfer1/utils/misc.py b/cosmos-transfer1/cosmos_transfer1/utils/misc.py deleted file mode 100644 index 9bb12466c9fd4ca77253a252803c722f59bb3b16..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/misc.py +++ /dev/null @@ -1,323 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import base64 -import collections -import collections.abc -import functools -import json -import os -import random -import tempfile -import time -from contextlib import ContextDecorator -from typing import Any, Callable, List, Tuple, TypeVar - -import cv2 -import numpy as np -import termcolor -import torch -from torch.distributed._functional_collectives import AsyncCollectiveTensor -from torch.distributed._tensor.api import DTensor - -from cosmos_transfer1.utils import distributed, log - - -def extract_video_frames(video_path, number_of_frames=2): - cap = cv2.VideoCapture(video_path) - frame_paths = [] - - temp_dir = tempfile.gettempdir() - for i in range(number_of_frames): # Extract first two frames - ret, frame = cap.read() - if not ret: - break # Stop if no more frames - - temp_path = os.path.join(temp_dir, f"frame_{i+1}.png") - cv2.imwrite(temp_path, frame) - frame_paths.append(temp_path) - - cap.release() - return frame_paths - - -def image_to_base64(image_path): - with open(image_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode("utf-8") - - -def to( - data: Any, - device: str | torch.device | None = None, - dtype: torch.dtype | None = None, - memory_format: torch.memory_format = torch.preserve_format, -) -> Any: - """Recursively cast data into the specified device, dtype, and/or memory_format. - - The input data can be a tensor, a list of tensors, a dict of tensors. - See the documentation for torch.Tensor.to() for details. - - Args: - data (Any): Input data. - device (str | torch.device): GPU device (default: None). - dtype (torch.dtype): data type (default: None). - memory_format (torch.memory_format): memory organization format (default: torch.preserve_format). - - Returns: - data (Any): Data cast to the specified device, dtype, and/or memory_format. 
- """ - assert ( - device is not None or dtype is not None or memory_format is not None - ), "at least one of device, dtype, memory_format should be specified" - if isinstance(data, torch.Tensor): - is_cpu = (isinstance(device, str) and device == "cpu") or ( - isinstance(device, torch.device) and device.type == "cpu" - ) - data = data.to( - device=device, - dtype=dtype, - memory_format=memory_format, - non_blocking=(not is_cpu), - ) - return data - elif isinstance(data, collections.abc.Mapping): - return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data}) - elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): - return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data]) - else: - return data - - -def get_local_tensor_if_DTensor(tensor: torch.Tensor | DTensor) -> torch.tensor: - if isinstance(tensor, DTensor): - local = tensor.to_local() - # As per PyTorch documentation, if the communication is not finished yet, we need to wait for it to finish - # https://pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.DTensor.to_local - if isinstance(local, AsyncCollectiveTensor): - return local.wait() - else: - return local - return tensor - - -def serialize(data: Any) -> Any: - """Serialize data by hierarchically traversing through iterables. - - Args: - data (Any): Input data. - - Returns: - data (Any): Serialized data. - """ - if isinstance(data, collections.abc.Mapping): - return type(data)({key: serialize(data[key]) for key in data}) - elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): - return type(data)([serialize(elem) for elem in data]) - else: - try: - json.dumps(data) - except TypeError: - data = str(data) - return data - - -def print_environ_variables(env_vars: list[str]) -> None: - """Print a specific list of environment variables. - - Args: - env_vars (list[str]): List of specified environment variables. - """ - for env_var in env_vars: - if env_var in os.environ: - log.info(f"Environment variable {Color.green(env_var)}: {Color.yellow(os.environ[env_var])}") - else: - log.warning(f"Environment variable {Color.green(env_var)} not set!") - - -def set_random_seed(seed: int, by_rank: bool = False) -> None: - """Set random seed. This includes random, numpy, Pytorch. - - Args: - seed (int): Random seed. - by_rank (bool): if true, each GPU will use a different random seed. - """ - if by_rank: - seed += distributed.get_rank() - log.info(f"Using random seed {seed}.") - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) # sets seed on the current CPU & all GPUs - - -def arch_invariant_rand( - shape: List[int] | Tuple[int], dtype: torch.dtype, device: str | torch.device, seed: int | None = None -): - """Produce a GPU-architecture-invariant randomized Torch tensor. - - Args: - shape (list or tuple of ints): Output tensor shape. - dtype (torch.dtype): Output tensor type. - device (torch.device): Device holding the output. - seed (int): Optional randomization seed. - - Returns: - tensor (torch.tensor): Randomly-generated tensor. 
- """ - # Create a random number generator, optionally seeded - rng = np.random.RandomState(seed) - - # # Generate random numbers using the generator - random_array = rng.standard_normal(shape).astype(np.float32) # Use standard_normal for normal distribution - - # Convert to torch tensor and return - return torch.from_numpy(random_array).to(dtype=dtype, device=device) - - -T = TypeVar("T", bound=Callable[..., Any]) - - -class timer(ContextDecorator): # noqa: N801 - """Simple timer for timing the execution of code. - - It can be used as either a context manager or a function decorator. The timing result will be logged upon exit. - - Example: - def func_a(): - time.sleep(1) - with timer("func_a"): - func_a() - - @timer("func_b) - def func_b(): - time.sleep(1) - func_b() - """ - - def __init__(self, context: str, debug: bool = False): - self.context = context - self.debug = debug - - def __enter__(self) -> None: - self.tic = time.time() - - def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 - time_spent = time.time() - self.tic - if self.debug: - log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") - else: - log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") - - def __call__(self, func: T) -> T: - @functools.wraps(func) - def wrapper(*args, **kwargs): # noqa: ANN202 - tic = time.time() - result = func(*args, **kwargs) - time_spent = time.time() - tic - if self.debug: - log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") - else: - log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") - return result - - return wrapper # type: ignore - - -class TrainingTimer: - """Timer for timing the execution of code, aggregating over multiple training iterations. - - It is used as a context manager to measure the execution time of code and store the timing results - for each function. The context managers can be nested. - - Attributes: - results (dict): A dictionary to store timing results for various code. - - Example: - timer = Timer() - for i in range(100): - with timer("func_a"): - func_a() - avg_time = sum(timer.results["func_a"]) / len(timer.results["func_a"]) - print(f"func_a() took {avg_time} seconds.") - """ - - def __init__(self) -> None: - self.results = dict() - self.average_results = dict() - self.start_time = [] - self.func_stack = [] - self.reset() - - def reset(self) -> None: - self.results = {key: [] for key in self.results} - - def __enter__(self) -> TrainingTimer: - self.start_time.append(time.time()) - return self - - def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 - end_time = time.time() - result = end_time - self.start_time.pop() - key = self.func_stack.pop() - self.results.setdefault(key, []) - self.results[key].append(result) - - def __call__(self, func_name: str) -> TrainingTimer: - self.func_stack.append(func_name) - return self - - def __getattr__(self, func_name: str) -> TrainingTimer: - return self.__call__(func_name) - - def nested(self, func_name: str) -> TrainingTimer: - return self.__call__(func_name) - - def compute_average_results(self) -> dict[str, float]: - results = dict() - for key, value_list in self.results.items(): - results[key] = sum(value_list) / len(value_list) - return results - - -def timeout_handler(timeout_period: float, signum: int, frame: int) -> None: - # What to do when the process gets stuck. For now, we simply end the process. - error_message = f"Timeout error: more than {timeout_period} seconds passed since the last iteration." 
- raise TimeoutError(error_message) - - -class Color: - """A convenience class to colorize strings in the console. - - Example: - import - print("This is {Color.red('important')}.") - """ - - @staticmethod - def red(x: str) -> str: - return termcolor.colored(str(x), color="red") - - @staticmethod - def green(x: str) -> str: - return termcolor.colored(str(x), color="green") - - @staticmethod - def cyan(x: str) -> str: - return termcolor.colored(str(x), color="cyan") - - @staticmethod - def yellow(x: str) -> str: - return termcolor.colored(str(x), color="yellow") diff --git a/cosmos-transfer1/cosmos_transfer1/utils/model.py b/cosmos-transfer1/cosmos_transfer1/utils/model.py deleted file mode 100644 index 54d06ce6f0aadf79dae64240c7219bfc1b4f2083..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/model.py +++ /dev/null @@ -1,136 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any - -import torch - -from cosmos_transfer1.utils.lazy_config import LazyDict, instantiate - - -class Model(torch.nn.Module): - """The base model class. It is inherited from torch.nn.Module. - - All models should inherit Model. It should include the implementions for all the - computation graphs. All inheriting child classes should implement the following methods: - - training_step(): The training step of the model, including the loss computation. - - validation_step(): The validation step of the model, including the loss computation. - - forward(): The computation graph for model inference. - The following methods have default implementations in Model: - - init_optimizer_scheduler(): Creates the optimizer and scheduler for the model. - """ - - def __init__(self) -> None: - super().__init__() - self.on_model_init_start(set_barrier=False) - - def init_optimizer_scheduler( - self, optimizer_config: LazyDict, scheduler_config: LazyDict - ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: - """Creates the optimizer and scheduler for the model. - - Args: - config_model (ModelConfig): The config object for the model. - - Returns: - optimizer (torch.optim.Optimizer): The model optimizer. - scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. - """ - optimizer_config.params = self.parameters() - optimizer = instantiate(optimizer_config) - scheduler_config.optimizer = optimizer - scheduler = instantiate(scheduler_config) - return optimizer, scheduler - - def training_step( - self, data_batch: dict[str, torch.Tensor], iteration: int - ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: - """The training step of the model, including the loss computation. - - Args: - data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). - iteration (int): Current iteration number. - - Returns: - output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch. 
- loss (torch.Tensor): The total loss for backprop (weighted sum of various losses). - """ - raise NotImplementedError - - @torch.no_grad() - def validation_step( - self, data_batch: dict[str, torch.Tensor], iteration: int - ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: - """The validation step of the model, including the loss computation. - - Args: - data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). - iteration (int): Current iteration number. - - Returns: - output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch. - loss (torch.Tensor): The total loss (weighted sum of various losses). - """ - raise NotImplementedError - - @torch.inference_mode() - def forward(self, *args: Any, **kwargs: Any) -> Any: - """The computation graph for model inference. - - Args: - *args: Whatever you decide to pass into the forward method. - **kwargs: Keyword arguments are also possible. - - Return: - Your model's output. - """ - raise NotImplementedError - - def on_model_init_start(self, set_barrier=False) -> None: - return - - def on_model_init_end(self, set_barrier=False) -> None: - return - - def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: - """The model preparation before the training is launched - - Args: - memory_format (torch.memory_format): Memory format of the model. - """ - pass - - def on_before_zero_grad( - self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int - ) -> None: - """Hook before zero_grad() is called. - - Args: - optimizer (torch.optim.Optimizer): The model optimizer. - scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. - iteration (int): Current iteration number. - """ - pass - - def on_after_backward(self, iteration: int = 0) -> None: - """Hook after loss.backward() is called. - - This method is called immediately after the backward pass, allowing for custom operations - or modifications to be performed on the gradients before the optimizer step. - - Args: - iteration (int): Current iteration number. - """ - pass diff --git a/cosmos-transfer1/cosmos_transfer1/utils/parallel_state_helper.py b/cosmos-transfer1/cosmos_transfer1/utils/parallel_state_helper.py deleted file mode 100644 index f531ab00c9d45a7dbf5015a43147bf635d72c5ec..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/parallel_state_helper.py +++ /dev/null @@ -1,24 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
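# Editor's note: a minimal, hedged sketch of how the Model base class defined
# above is intended to be subclassed, based only on its docstrings:
# training_step/validation_step return an (output_batch, loss) tuple and
# forward runs inference. The toy module below is illustrative and is not
# part of this repository.
import torch

from cosmos_transfer1.utils.model import Model


class ToyRegressionModel(Model):
    def __init__(self) -> None:
        super().__init__()
        self.net = torch.nn.Linear(8, 1)

    def training_step(self, data_batch: dict[str, torch.Tensor], iteration: int):
        pred = self.net(data_batch["x"])
        loss = torch.nn.functional.mse_loss(pred, data_batch["y"])
        return {"pred": pred}, loss

    @torch.no_grad()
    def validation_step(self, data_batch: dict[str, torch.Tensor], iteration: int):
        pred = self.net(data_batch["x"])
        loss = torch.nn.functional.mse_loss(pred, data_batch["y"])
        return {"pred": pred}, loss

    @torch.inference_mode()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)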
- -from megatron.core import parallel_state - - -def is_tp_cp_pp_rank0(): - return ( - parallel_state.get_tensor_model_parallel_rank() == 0 - and parallel_state.get_pipeline_model_parallel_rank() == 0 - and parallel_state.get_context_parallel_rank() == 0 - ) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/regional_prompting_utils.py b/cosmos-transfer1/cosmos_transfer1/utils/regional_prompting_utils.py deleted file mode 100644 index 4014335830b7607265e402a8a4a3ccc6d163f2ab..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/regional_prompting_utils.py +++ /dev/null @@ -1,257 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Tuple, Union - -import matplotlib.pyplot as plt -import torch - -from cosmos_transfer1.utils import log - - -class RegionalPromptProcessor: - """ - Processes regional prompts and creates corresponding masks for attention. - """ - - def __init__(self, max_img_h, max_img_w, max_frames): - self.max_img_h = max_img_h - self.max_img_w = max_img_w - self.max_frames = max_frames - - def create_region_masks_from_boxes( - self, - bounding_boxes: List[List[float]], - batch_size: int, - time_dim: int, - height: int, - width: int, - device: torch.device, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Create region masks from bounding boxes [x1, y1, x2, y2] in normalized coordinates (0-1). - - Returns: - region_masks: Tensor of shape (B, R, T, H, W) with values between 0 and 1 - """ - num_regions = len(bounding_boxes) - region_masks = torch.zeros( - batch_size, num_regions, time_dim, height, width, device=device, dtype=torch.bfloat16 - ) - - for r, box in enumerate(bounding_boxes): - # Convert normalized coordinates to pixel coordinates - x1, y1, x2, y2 = box - x1 = int(x1 * width) - y1 = int(y1 * height) - x2 = int(x2 * width) - y2 = int(y2 * height) - - # Create mask for this region - region_masks[:, r, :, y1:y2, x1:x2] = 1.0 - - return region_masks - - def create_region_masks_from_segmentation( - self, - segmentation_maps: List[torch.Tensor], - batch_size: int, - time_dim: int, - height: int, - width: int, - device: torch.device, - ) -> torch.Tensor: - """ - Create masks from binary segmentation maps. 
- - Args: - segmentation_maps: List of Tensors, each of shape (T, H, W) with binary values - - Returns: - region_masks: Tensor of shape (B, R, T, H, W) with binary values - """ - num_regions = len(segmentation_maps) - region_masks = torch.zeros( - batch_size, num_regions, time_dim, height, width, device=device, dtype=torch.bfloat16 - ) - - for r, seg_map in enumerate(segmentation_maps): - # Clip to 121 frames if longer - if seg_map.shape[0] > time_dim: - log.info(f"clipping segmentation map to {time_dim} frames") - seg_map = seg_map[:time_dim] - region_masks[:, r] = seg_map.float() - - return region_masks - - def visualize_region_masks( - self, region_masks: torch.Tensor, save_path: str, time_dim: int, height: int, width: int - ) -> None: - """ - Visualize region masks for debugging purposes. - - Args: - region_masks: Tensor of shape (B, R, T*H*W) - save_path: Path to save the visualization - time_dim: Number of frames - height: Height in latent space - width: Width in latent space - """ - - B, R, T, H, W = region_masks.shape - reshaped_masks = region_masks - - # Create figure - fig, axes = plt.subplots(R, 1, figsize=(10, 3 * R)) - if R == 1: - axes = [axes] - for r in range(R): - axes[r].imshow(reshaped_masks[r, time_dim // 2].cpu().numpy(), cmap="gray") - axes[r].set_title(f"Region {r+1} Mask (Middle Frame)") - plt.tight_layout() - plt.savefig(save_path) - plt.close() - - -def compress_segmentation_map(segmentation_map, compression_factor): - # Handle both [T,H,W] and [C,T,H,W] formats - if len(segmentation_map.shape) == 4: # [C,T,H,W] format - C, T, H, W = segmentation_map.shape - # Assuming first channel contains the main segmentation mask - # Can be modified based on specific requirements - segmentation_map = segmentation_map[0] # Take first channel, now [T,H,W] - - # Add batch and channel dimensions [1, 1, T, H, W] - expanded_map = segmentation_map.unsqueeze(0).unsqueeze(0) - T, H, W = segmentation_map.shape - new_H = H // compression_factor - new_W = W // compression_factor - - compressed_map = torch.nn.functional.interpolate( - expanded_map, size=(T, new_H, new_W), mode="trilinear", align_corners=False - ) - - return compressed_map.squeeze(0).squeeze(0) - - -def prepare_regional_prompts( - model, - global_prompt: Union[str, torch.Tensor], - regional_prompts: torch.Tensor, - region_definitions: List[Union[List[float], str]], - batch_size: int, - time_dim: int, - height: int, - width: int, - device: torch.device, - cache_dir: str = None, - local_files_only: bool = False, - visualize_masks: bool = False, - visualization_path: str = None, - compression_factor: int = 1, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Prepare regional prompts and masks for inference. 
- - Args: - model: DiT model - global_prompt: Global text prompt or pre-computed embedding - regional_prompts: List of regional text prompts - region_definitions: List of bounding boxes [x1, y1, x2, y2] or segmentation map - batch_size: Batch size - time_dim: Number of frames - height: Height in latent space - width: Width in latent space - device: Device to create tensors on - cache_dir: Cache directory for text encoder - local_files_only: Whether to use only local files for text encoder - visualize_masks: Whether to visualize the region masks for debugging - visualization_path: Path to save the visualization - - Returns: - global_context: Global prompt embedding - regional_contexts: List of regional prompt embeddings - region_masks: Region masks tensor with values between 0 and 1 - """ - processor = RegionalPromptProcessor(max_img_h=height, max_img_w=width, max_frames=time_dim) - - # Validate that we have matching number of prompts and region definitions - if len(regional_prompts) != len(region_definitions): - raise ValueError( - f"Number of regional prompts ({len(regional_prompts)}) must match " - f"total number of region definitions ({len(region_definitions)})" - ) - - # Track which prompts correspond to which region types while maintaining order - box_prompts = [] - seg_prompts = [] - prompt_idx = 0 - - segmentation_maps: List[torch.Tensor] = [] - region_definitions_list: List[List[float]] = [] - # Maintain correspondence between prompts and region definitions - for region_definition in region_definitions: - if isinstance(region_definition, str): - segmentation_map = torch.load(region_definition, weights_only=False) - # Validate segmentation map dimensions - if len(segmentation_map.shape) not in [3, 4]: - raise ValueError( - f"Segmentation map should have shape [T,H,W] or [C,T,H,W], got shape {segmentation_map.shape}" - ) - - segmentation_map = compress_segmentation_map(segmentation_map, compression_factor) - log.info(f"segmentation_map shape: {segmentation_map.shape}") - segmentation_maps.append(segmentation_map) - seg_prompts.append(regional_prompts[prompt_idx]) - elif isinstance(region_definition, list): - region_definitions_list.append(region_definition) - box_prompts.append(regional_prompts[prompt_idx]) - else: - raise ValueError(f"Region definition format not recognized: {type(region_definition)}") - prompt_idx += 1 - - # Update regional_prompts to maintain correct ordering - regional_prompts = box_prompts + seg_prompts - region_masks_boxes = processor.create_region_masks_from_boxes( - region_definitions_list, batch_size, time_dim, height, width, device - ) - region_masks_segmentation = processor.create_region_masks_from_segmentation( - segmentation_maps, batch_size, time_dim, height, width, device - ) - region_masks = torch.cat([region_masks_boxes, region_masks_segmentation], dim=1) - - if visualize_masks and visualization_path: - processor.visualize_region_masks(region_masks, visualization_path, time_dim, height, width) - - if isinstance(global_prompt, str): - pass - elif isinstance(global_prompt, torch.Tensor): - global_context = global_prompt.to(dtype=torch.bfloat16) - else: - raise ValueError("Global prompt format not recognized.") - - regional_contexts = [] - for regional_prompt in regional_prompts: - if isinstance(regional_prompt, str): - raise ValueError(f"Regional prompt should be converted to embedding: {type(regional_prompt)}") - elif isinstance(regional_prompt, torch.Tensor): - regional_context = regional_prompt.to(dtype=torch.bfloat16) - else: - raise 
ValueError(f"Regional prompt format not recognized: {type(regional_prompt)}") - - regional_contexts.append(regional_context) - - regional_contexts = torch.stack(regional_contexts, dim=1) - return global_context, regional_contexts, region_masks diff --git a/cosmos-transfer1/cosmos_transfer1/utils/t5_text_encoder.py b/cosmos-transfer1/cosmos_transfer1/utils/t5_text_encoder.py deleted file mode 100644 index 0739d419d8c449a07f763fae49b1284e762157e1..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/t5_text_encoder.py +++ /dev/null @@ -1,108 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Tuple, Union - -import torch -import transformers -from transformers import T5EncoderModel, T5TokenizerFast - -from cosmos_transfer1.utils import log - -transformers.logging.set_verbosity_error() - - -class CosmosT5TextEncoder(torch.nn.Module): - """Handles T5 text encoding operations.""" - - def __init__(self, model_name: str = "google-t5/t5-11b", device: str = "cuda", cache_dir: str = "~/.cache"): - """Initializes the T5 tokenizer and encoder. - - Args: - model_name: The name of the T5 model to use. - device: The device to use for computations. - """ - super().__init__() - try: - self.tokenizer = T5TokenizerFast.from_pretrained(cache_dir, cache_dir=cache_dir) - self.text_encoder = T5EncoderModel.from_pretrained(cache_dir, cache_dir=cache_dir).to(device) - except Exception as e: - log.warning(f"Failed to load T5 model using cache_dir '{cache_dir}', falling back to default location: {e}") - self.tokenizer = T5TokenizerFast.from_pretrained(model_name) - self.text_encoder = T5EncoderModel.from_pretrained(model_name).to(device) - self.text_encoder.eval() - self.device = device - - @torch.inference_mode() - def encode_prompts( - self, prompts: Union[str, List[str]], max_length: int = 512 - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Encodes text prompts into hidden state representations using a T5 encoder. - - This function tokenizes the input prompts, processes them through a T5 text encoder, - and returns the last hidden states. The encoded outputs beyond the actual sequence - length are zero-padded. All prompts in a batch are padded to max_length. - - Args: - prompts: Input text to encode. Can be a single string or a list of strings. - max_length: Maximum sequence length for tokenization and padding. Longer - sequences will be truncated. Defaults to 512. - return_mask: If True, returns the attention mask along with encoded text. - Defaults to False. - - Returns: - If return_mask is False: - torch.Tensor: Encoded text embeddings of shape (batch_size, max_length, hidden_size). 
- If return_mask is True: - tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Encoded text embeddings of shape (batch_size, max_length, hidden_size) - - Attention mask of shape (batch_size, max_length) as boolean tensor - - Raises: - ValueError: If the input prompts list is empty. - - Example: - >>> encoder = CosmosT5TextEncoder() - >>> prompts = ["Hello world", "Another example"] - >>> embeddings = encoder.encode_prompts(prompts, max_length=128) - """ - if isinstance(prompts, str): - prompts = [prompts] - - if not prompts: - raise ValueError("The input prompt list is empty.") - - batch_encoding = self.tokenizer.batch_encode_plus( - prompts, - return_tensors="pt", - truncation=True, - padding="max_length", - max_length=max_length, - return_length=True, - return_offsets_mapping=False, - ) - - input_ids = batch_encoding.input_ids.to(self.device) - attn_mask = batch_encoding.attention_mask.to(self.device) - - outputs = self.text_encoder(input_ids=input_ids, attention_mask=attn_mask) - - encoded_text = outputs.last_hidden_state - lengths = attn_mask.sum(dim=1).cpu() - - for batch_id in range(encoded_text.shape[0]): - encoded_text[batch_id][lengths[batch_id] :] = 0 - - return encoded_text, attn_mask diff --git a/cosmos-transfer1/cosmos_transfer1/utils/trainer.py b/cosmos-transfer1/cosmos_transfer1/utils/trainer.py deleted file mode 100644 index 4efbfcdfc83344af58bb2ac0398cce651f445512..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/trainer.py +++ /dev/null @@ -1,281 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import os -import signal - -import torch -import torch.distributed as dist -import torch.utils.data -from megatron.core import parallel_state - -from cosmos_transfer1.utils import callback, distributed, ema, log, misc -from cosmos_transfer1.utils.checkpointer import Checkpointer -from cosmos_transfer1.utils.lazy_config import LazyConfig, instantiate -from cosmos_transfer1.utils.model import Model - - -class Trainer: - """The base trainer class. - - All trainers should inherit Trainer. It contains the basic functionality for model training - (particularly suited for large-scale training), including data parallel (DDP/FSDP), model weight average (EMA), - mixed-precision training (fp16/bf16). - - Attributes: - checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states. - training_timer (misc.Timer): Timer object to time code blocks and functions. - """ - - def __init__(self, config): - """Constructor of the trainer. - - Args: - config (Config): The config object for the codebase. - """ - super().__init__() - self.config = config - # Set up the distributed computing environment. - with misc.timer("init_distributed"): - distributed.init() - # Set up parallel states. 
- if hasattr(config.model, "context_parallel_size"): - if config.model_parallel.context_parallel_size > 1: - raise ValueError( - "Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. " - "config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size." - ) - else: - log.critical( - "Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead." - ) - config.model_parallel.context_parallel_size = config.model.context_parallel_size - parallel_state.initialize_model_parallel( - pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size, - tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size, - context_parallel_size=config.model_parallel.context_parallel_size, - ) - # `config.model_parallel.sequence_parallel` is a bool that indicates whether to use sequence parallelism. - # It is not part of the original `parallel_state` API, so we need to set it manually. - parallel_state.sequence_parallel = config.model_parallel.sequence_parallel - if parallel_state.sequence_parallel: - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - - # Create the local job directory, save the config file, and pipe to a local log. - if distributed.is_rank0(): - os.makedirs(config.job.path_local, exist_ok=True) - # Save the config as .pkl for reproducibility. - LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl") - # Save the config as .yaml for reading or parsing experiment hyperparameters. - LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") - dist.barrier() - log.init_loguru_file(f"{config.job.path_local}/stdout.log") - if distributed.is_rank0(): - # Print important environment variables and the effective config. - log.info("Config:\n" + config.pretty_print(use_color=True)) - misc.print_environ_variables(["OUTPUT_ROOT"]) - # Set the random seed. If multi-GPU, different ranks are set with different seeds. - misc.set_random_seed(seed=config.trainer.seed, by_rank=True) - # Initialize cuDNN. - torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic - torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark - # Floating-point precision settings. - torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True - # Initialize the callback functions. - self.callbacks = callback.CallBackGroup(config=config, trainer=self) - # Initialize the model checkpointer. - if config.checkpoint.type is None: - self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks) - else: - self.checkpointer: Checkpointer = instantiate( - config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks - ) - # Initialize the timer for speed benchmarking. - self.training_timer = misc.TrainingTimer() - # Send a TimeoutError if a training step takes over timeout_period seconds. - signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period)) # type: ignore - - def train( - self, - model: Model, - dataloader_train: torch.utils.data.DataLoader, - dataloader_val: torch.utils.data.DataLoader, - ) -> None: - """The training function. - - Args: - model (Model): The PyTorch model. - dataloader_train (torch.utils.data.DataLoader): The training data loader. - dataloader_val (torch.utils.data.DataLoader): The validation data loader. 
- """ - # Leaving this for backward compability for now, but we can think about moving this to model.on_train_start for all models. - model = model.to("cuda", memory_format=self.config.trainer.memory_format) # type: ignore - model.on_train_start(self.config.trainer.memory_format) - - # Initialize the optimizer, scheduler, and grad_scaler. - self.callbacks.on_optimizer_init_start() - optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler) - grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args) - self.callbacks.on_optimizer_init_end() - # Load the model checkpoint and get the starting iteration number. - iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler) - grad_accum_iter = 0 - log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}") - if self.config.trainer.distributed_parallelism == "ddp": - # Create a DDP model wrapper. - model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model) - elif self.config.trainer.distributed_parallelism == "fsdp": - model_ddp = model - else: - raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}") - log.info("Starting training...") - self.callbacks.on_train_start(model, iteration=iteration) - # Initial validation. - if self.config.trainer.run_validation and iteration == 0: - self.validate(model, dataloader_val, iteration=iteration) - _end_training = False - while True: - dataloader_train_iter = iter(dataloader_train) - while True: - self.callbacks.on_before_dataloading(iteration) - with self.training_timer("dataloader_train"): - try: - data_batch = next(dataloader_train_iter) - except StopIteration: - break - self.callbacks.on_after_dataloading(iteration) - # If max_iter is reached, exit the training loop. - if iteration >= self.config.trainer.max_iter: - _end_training = True - break - # Move all tensors in the data batch to GPU device. - data_batch = misc.to(data_batch, device="cuda") - # The actual training step. - self.callbacks.on_training_step_start(model, data_batch, iteration=iteration) - if not model.training: - model_ddp.train() - assert model_ddp.training, "model_ddp is not in training mode." - assert model.training, "model is not in training mode." - output_batch, loss, grad_accum_iter = self.training_step( - model_ddp, - optimizer, - scheduler, - grad_scaler, - data_batch, - iteration=iteration, - grad_accum_iter=grad_accum_iter, - ) - # Do the following when an actual optimizer (update) step has been made. - iteration += 1 - # Save checkpoint. - if iteration % self.config.checkpoint.save_iter == 0: - self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration) - self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration) - # Validation. - if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0: - self.validate(model, dataloader_val, iteration=iteration) - # This iteration is successful; reset the timeout signal. 
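                # (Editor's note: this re-arms the SIGALRM watchdog registered in
                # __init__ via misc.timeout_handler; if the next iteration does not
                # complete within config.trainer.timeout_period seconds, a
                # TimeoutError is raised.)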
- signal.alarm(self.config.trainer.timeout_period) - if _end_training: - break - log.success("Done with training.") - if iteration % self.config.checkpoint.save_iter != 0: - self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration) - self.callbacks.on_train_end(model, iteration=iteration) - self.checkpointer.finalize() - log.info("Cleaning up distributed environment..") - distributed.barrier() - log.info("Cleaning up distributed environment.. Done!") - self.callbacks.on_app_end() - - def training_step( - self, - model_ddp: torch.nn.Module | distributed.DistributedDataParallel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - data: dict[str, torch.Tensor], - iteration: int = 0, - grad_accum_iter: int = 0, - ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]: - """The training step. - - Args: - model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or, the bare - module, depending on whether distributed training is enabled or not. - optimizer (torch.optim.Optimizer): The model optimizer. - scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. - grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). - data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). - iteration (int): Current iteration number. - grad_accum_iter (int): Number of gradient accumulation iterations. - - Returns: - output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors). - loss (torch.Tensor): The total loss of the training data batch. - """ - # Only let DDP sync gradient at the last iteration of the gradient accumulation window - with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1): - with self.training_timer("forward"): - output_batch, loss = model_ddp.training_step(data, iteration) - self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration) - with self.training_timer("backward"): - loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter) - loss_scaled.backward() - if self.config.trainer.distributed_parallelism == "ddp": - model_ddp.module.on_after_backward() - else: - model_ddp.on_after_backward() - self.callbacks.on_after_backward(model_ddp, iteration=iteration) - grad_accum_iter += 1 - if grad_accum_iter == self.config.trainer.grad_accum_iter: - with self.training_timer("optimizer_step"): - self.callbacks.on_before_optimizer_step( - model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration - ) - grad_scaler.step(optimizer) - grad_scaler.update() - scheduler.step() - self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration) - if self.config.trainer.distributed_parallelism == "ddp": - model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration) - else: - model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration) - optimizer.zero_grad(set_to_none=True) - grad_accum_iter = 0 - return output_batch, loss, grad_accum_iter - - @torch.no_grad() - def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None: - """Validate on the full validation dataset. - - Args: - model (Model): The PyTorch model. - dataloader_val (torch.utils.data.DataLoader): The validation data loader. - iteration (int): Current iteration number. 
- """ - self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration) - model.eval() - # Evaluate on the full validation set. - with ema.ema_scope(model, enabled=model.config.ema.enabled): - for val_iter, data_batch in enumerate(dataloader_val): - if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter: - break - data_batch = misc.to(data_batch, device="cuda") - self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration) - output_batch, loss = model.validation_step(data_batch, iteration) - self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration) - self.callbacks.on_validation_end(model, iteration=iteration) diff --git a/cosmos-transfer1/cosmos_transfer1/utils/video_utils.py b/cosmos-transfer1/cosmos_transfer1/utils/video_utils.py deleted file mode 100644 index b4f32755eeb95d12a78b20668e02324cfbbc85e8..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/cosmos_transfer1/utils/video_utils.py +++ /dev/null @@ -1,115 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from typing import Tuple - -import cv2 -import magic -import numpy as np -import torch - -from cosmos_transfer1.utils import log - -# Supported video extensions and corresponding MIME types -SUPPORTED_VIDEO_TYPES = { - ".mp4": "video/mp4", - ".mkv": "video/x-matroska", - ".mov": "video/quicktime", - ".avi": "video/x-msvideo", - ".webm": "video/webm", - ".flv": "video/x-flv", - ".wmv": "video/x-ms-wmv", -} - - -def video_to_tensor(video_path: str, output_path: str, normalize: bool = True) -> Tuple[torch.Tensor, float]: - """Convert an MP4 video file to a tensor and save it as a .pt file. 
- Args: - video_path (str): Path to input MP4 video file - output_path (str): Path to save output .pt tensor file - normalize (bool): Whether to normalize pixel values to [-1,1] range (default: True) - - Returns: - Tuple[torch.Tensor, float]: Tuple containing: - - Video tensor in shape [C,T,H,W] - - Video FPS - """ - # Open video file - cap = cv2.VideoCapture(video_path) - if not cap.isOpened(): - raise ValueError(f"Failed to open video file: {video_path}") - - # Get video properties - fps = cap.get(cv2.CAP_PROP_FPS) - frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - - # Read first frame to get dimensions - ret, frame = cap.read() - if not ret: - raise ValueError(f"Failed to read frames from video: {video_path}") - - height, width = frame.shape[:2] - - # Reset video to beginning - cap.set(cv2.CAP_PROP_POS_FRAMES, 0) - - # Initialize tensor to store frames - frames = [] - - # Read all frames - while True: - ret, frame = cap.read() - if not ret: - break - - # Convert BGR to RGB - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frames.append(frame) - - cap.release() - - log.info(f"frames: {len(frames)}") - # Convert frames to tensor - video_tensor = torch.from_numpy(np.array(frames)) - log.info(f"video_tensor shape: {video_tensor.shape}") - # Reshape from [T,H,W,C] to [C,T,H,W] - video_tensor = video_tensor.permute(3, 0, 1, 2) - - # Normalize if requested - if normalize: - video_tensor = video_tensor.float() / 127.5 - 1.0 - - # Save tensor - os.makedirs(os.path.dirname(output_path), exist_ok=True) - torch.save(video_tensor, output_path) - - return video_tensor, fps - - -def is_valid_video(file_path: str) -> bool: - if not os.path.isfile(file_path): - return False - - ext = os.path.splitext(file_path)[1].lower() - expected_mime = SUPPORTED_VIDEO_TYPES.get(ext) - - if not expected_mime: - return False # Extension not supported - - # Detect MIME type from actual file content - detected_mime = magic.from_file(file_path, mime=True) - - return detected_mime == expected_mime diff --git a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b.md b/cosmos-transfer1/examples/inference_cosmos_transfer1_7b.md deleted file mode 100644 index 166225566ce77bd8ba14ef77cb9def560fae4654..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b.md +++ /dev/null @@ -1,382 +0,0 @@ -# Cosmos-Transfer1: World Generation with Adaptive Multimodal Control - -## Install Cosmos-Transfer1 - -### Environment setup - -Please refer to the Inference section of [INSTALL.md](/INSTALL.md#inference) for instructions on environment setup. - -### Download Checkpoints - -1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token. Set the access token to 'Read' permission (default is 'Fine-grained'). - -2. Log in to Hugging Face with the access token: - -```bash -huggingface-cli login -``` - -3. Accept the [Llama-Guard-3-8B terms](https://huggingface.co/meta-llama/Llama-Guard-3-8B) - -4. Download the Cosmos model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-transfer1-67c9d328196453be6e568d3e): - -```bash -PYTHONPATH=$(pwd) python scripts/download_checkpoints.py --output_dir checkpoints/ -``` - -Note that this will require about 300GB of free storage. Not all these checkpoints will be used in every generation. - -5. The downloaded files should be in the following structure: - -``` -checkpoints/ -├── nvidia -│ │ -│ ├── Cosmos-Guardrail1 -│ │ ├── README.md -│ │ ├── blocklist/... -│ │ ├── face_blur_filter/... 
-│ │ └── video_content_safety_filter/... -│ │ -│ ├── Cosmos-Transfer1-7B -│ │ ├── base_model.pt -│ │ ├── vis_control.pt -│ │ ├── edge_control.pt -│ │ ├── seg_control.pt -│ │ ├── depth_control.pt -│ │ ├── 4kupscaler_control.pt -│ │ └── config.json -│ │ -│ ├── Cosmos-Transfer1-7B-Sample-AV/ -│ │ ├── base_model.pt -│ │ ├── hdmap_control.pt -│ │ └── lidar_control.pt -│ │ -│ │── Cosmos-Tokenize1-CV8x8x8-720p -│ │ ├── decoder.jit -│ │ ├── encoder.jit -│ │ ├── autoencoder.jit -│ │ └── mean_std.pt -│ │ -│ └── Cosmos-UpsamplePrompt1-12B-Transfer -│ ├── depth -│ │ ├── consolidated.safetensors -│ │ ├── params.json -│ │ └── tekken.json -│ ├── README.md -│ ├── segmentation -│ │ ├── consolidated.safetensors -│ │ ├── params.json -│ │ └── tekken.json -│ ├── seg_upsampler_example.png -│ └── viscontrol -│ ├── consolidated.safetensors -│ ├── params.json -│ └── tekken.json -│ -├── depth-anything/... -├── facebook/... -├── google-t5/... -├── IDEA-Research/... -└── meta-llama/... -``` - -## Sample Commands - -Here's an example command: - -```bash -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}" -export NUM_GPU="${NUM_GPU:=1}" -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir checkpoints \ - --input_video_path path/to/input_video.mp4 \ - --video_save_name output_video \ - --controlnet_specs spec.json \ - --offload_guardrail_models \ - --num_gpus $NUM_GPU -``` - -Cosmos-Transfer1 supports a variety of configurations. You can pass your configuration in a JSON file via the argument `--controlnet_specs`. Let's go over a few examples: - -### Example 1: single control (Edge) - -The following `controlnet_specs` only activates the edge controlnet. - -```bash -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPU="${NUM_GPU:=1}" -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_folder outputs/example1_single_control_edge \ - --controlnet_specs assets/inference_cosmos_transfer1_single_control_edge.json \ - --offload_text_encoder_model \ - --offload_guardrail_models \ - --num_gpus $NUM_GPU -``` - -You can also choose to run the inference on multiple GPUs as follows: - -```bash -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0,1,2,3}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPU="${NUM_GPU:=4}" -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_folder outputs/example1_single_control_edge \ - --controlnet_specs assets/inference_cosmos_transfer1_single_control_edge.json \ - --offload_text_encoder_model \ - --offload_guardrail_models \ - --num_gpus $NUM_GPU -``` - -This launches `transfer.py` and configures the controlnets for inference according to `assets/inference_cosmos_transfer1_single_control_edge.json`: - -```json -{ - "prompt": "The video is set in a modern, well-lit office environment with a sleek, minimalist design. ...", - "input_video_path" : "assets/example1_input_video.mp4", - "edge": { - "control_weight": 1.0 - } -} -``` - -The input video is a low-resolution 640 × 480 video. - - - -We generate a 960 x 704 video. 
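-
-If you are scripting many runs, a spec like the one above can also be written programmatically before invoking the CLI. Below is a minimal, illustrative sketch (the output file name `my_edge_spec.json` is hypothetical; the fields mirror `assets/inference_cosmos_transfer1_single_control_edge.json` shown above):
-
-```python
-import json
-
-# Fields mirror the example edge spec above; the prompt and weight are placeholders to adapt per run.
-spec = {
-    "prompt": "The video is set in a modern, well-lit office environment with a sleek, minimalist design. ...",
-    "input_video_path": "assets/example1_input_video.mp4",
-    "edge": {"control_weight": 1.0},
-}
-
-with open("my_edge_spec.json", "w") as f:
-    json.dump(spec, f, indent=4)
-
-# Pass the generated file to transfer.py via --controlnet_specs my_edge_spec.json
-```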
- - - -You can use our prompt upsampler to convert your short prompt into a longer, more detailed prompt for video generation by using the `--upsample_prompt` argument. - - -```bash -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPU="${NUM_GPU:=1}" -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_folder outputs/example1_single_control_edge_upsampled_prompt \ - --controlnet_specs assets/inference_cosmos_transfer1_single_control_edge_short_prompt.json \ - --offload_text_encoder_model \ - --upsample_prompt \ - --offload_prompt_upsampler \ - --offload_guardrail_models \ - --num_gpus $NUM_GPU -``` - - -Input prompt: *Robotic arms hand over a coffee cup to a woman in a modern office.* - -Upsampled prompt: *The video opens with a close-up of a robotic arm holding a coffee cup with a lid, positioned next to a coffee machine. The arm is metallic with a black wrist, and the coffee cup is white with a brown lid. The background shows a modern office environment with a woman in a blue top and black pants standing in the distance. As the video progresses, the robotic arm moves the coffee cup towards the woman, who approaches to receive it. The woman has long hair and is wearing a blue top and black pants. The office has a contemporary design with glass partitions, potted plants, and other office furniture.* - -Here is the generated video using the upsampled prompt. - - - -### Example 2: multimodal control - -The following `controlnet_specs` activates vis, edge, depth, seg controls at the same time and apply uniform spatial weights. - -```bash -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPU="${NUM_GPU:=1}" -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_folder outputs/example2_uniform_weights \ - --controlnet_specs assets/inference_cosmos_transfer1_uniform_weights.json \ - --offload_text_encoder_model \ - --offload_guardrail_models \ - --num_gpus $NUM_GPU -``` - -This launches `transfer.py` and configures the controlnets for inference according to `assets/inference_cosmos_transfer1_uniform_weights.json`: - -```json -{ - "prompt": "The video is set in a modern, well-lit office environment with a sleek, minimalist design. ...", - "input_video_path" : "assets/example1_input_video.mp4", - "vis": { - "control_weight": 0.25 - }, - "edge": { - "control_weight": 0.25 - }, - "depth": { - "input_control": "assets/example1_depth.mp4", - "control_weight": 0.25 - }, - "seg": { - "input_control": "assets/example1_seg.mp4", - "control_weight": 0.25 - } -} -``` - -The output video can be found at `assets/example1_uniform_weights.mp4`. - - -#### Explanation of the controlnet spec -* `prompt` specifies the global prompt that all underlying networks will receive. -* `input_video_path` specifies the input video -* `sigma_max` specifies the level of noise that should be added to the input video before feeding through the base model branch -* The dictionaries `vis`, `edge`, `depth`, and `seg` activate the corresponding controlnet branches. -* The `control_weight` parameter is a number within the range [0, 1] that controls how strongly the controlnet branch should affect the output of the model. 
The larger the value (closer to 1.0), the more strongly the generated video will adhere to the controlnet input. However, this rigidity may come at a cost of quality. Lower (closer to 0) values would give more creative liberty to the model at the cost of reduced adherence. Usually a middle-ground value, say 0.5, yields optimal results.
-* The inputs to each controlnet branch are automatically computed according to the branch:
-  * `vis` applies bilateral blurring on the input video to compute the `input_control` for that branch
-  * `edge` uses [Canny Edge Detection](https://docs.opencv.org/4.x/da/d22/tutorial_py_canny.html) to compute the Canny edge `input_control` from the input video
-  * `depth` uses [DepthAnything](https://github.com/DepthAnything/Depth-Anything-V2)
-  * `seg` uses [Segment Anything Model 2](https://ai.meta.com/sam2/) for generating the segmentation map as `input_control` from the input video.
-
-#### Additional Information
-- At each spatiotemporal site, if the sum of the control maps across different modalities is greater than one, we apply normalization to the modality weights so that they sum up to one.
-- For `depth` and `seg`, if the `input_control` is not provided, we will run DepthAnything2 and GroundingDino+SAM2 on `input_video_path` to generate the corresponding `input_control`. Please see `assets/inference_cosmos_transfer1_uniform_weights_auto.json` as an example.
-- For `seg`, `input_control_prompt` can be provided to customize the prompt sent to GroundingDino. We can use ` . ` to separate objects in the `input_control_prompt`, e.g. `robotic arms . woman . cup`, as suggested by [GroundingDino](https://github.com/IDEA-Research/GroundingDINO?tab=readme-ov-file#arrow_forward-demo). If `input_control_prompt` is not provided, `prompt` will be used by default. Please see `assets/inference_cosmos_transfer1_uniform_weights_auto.json` as an example.
-
-### Example 3: multimodal control with spatiotemporal control map
-
-The following `controlnet_specs` activates the vis, edge, depth, and seg controls at the same time and applies spatiotemporal weights.
-
-```bash
-export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}"
-export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}"
-export NUM_GPU="${NUM_GPU:=1}"
-PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \
-    --checkpoint_dir $CHECKPOINT_DIR \
-    --video_save_folder outputs/example3_spatiotemporal_weights \
-    --controlnet_specs assets/inference_cosmos_transfer1_spatiotemporal_weights_auto.json \
-    --offload_text_encoder_model \
-    --offload_guardrail_models \
-    --num_gpus $NUM_GPU
-```
-
-This launches `transfer.py` and configures the controlnets for inference according to `assets/inference_cosmos_transfer1_spatiotemporal_weights_auto.json`:
-
-```json
-{
-    "prompt": "The video is set in a modern, well-lit office environment with a sleek, minimalist design...",
-    "input_video_path" : "assets/example1_input_video.mp4",
-    "vis": {
-        "control_weight": 0.5,
-        "control_weight_prompt": "robotic arms . gloves"
-    },
-    "edge": {
-        "control_weight": 0.5,
-        "control_weight_prompt": "robotic arms . gloves"
-    },
-    "depth": {
-        "control_weight": 0.5
-    },
-    "seg": {
-        "control_weight": 0.5
-    }
-}
-```
-
-The output video can be found at `assets/example1_spatiotemporal_weights.mp4` and is shown below.
-
-
-The spatiotemporal mask extracted by the prompt `robotic arms . gloves` is shown below.
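-
-As a side note, the weight-normalization rule from the Additional Information above can be made concrete with a small standalone sketch. This is illustrative only (not the repository's implementation) and assumes per-modality weight maps of shape `[num_modalities, T, H, W]`:
-
-```python
-import numpy as np
-
-def normalize_control_weights(weight_maps: np.ndarray) -> np.ndarray:
-    """Where the per-site sum of modality weights exceeds one, rescale them to sum to one."""
-    total = weight_maps.sum(axis=0, keepdims=True)                   # [1, T, H, W]
-    scale = np.where(total > 1.0, 1.0 / np.clip(total, 1e-8, None), 1.0)
-    return weight_maps * scale
-
-# Example: vis, edge, depth, and seg all active with weight 0.5 (sum = 2.0 per site).
-w = 0.5 * np.ones((4, 1, 4, 4))
-print(normalize_control_weights(w).sum(axis=0).max())  # 1.0 after normalization
-```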
-
-
-
-#### Explanation of the controlnet spec
-
-The controlnet spec is similar to Example 2 above, with the following exceptions:
-* An additional `control_weight_prompt` for the `vis` and `edge` modalities. This triggers the GroundingDINO+SAM2 pipeline to run video segmentation of the input video using `control_weight_prompt` (e.g. `robotic arms . gloves`) and to extract a binarized spatiotemporal mask in which positive pixels receive a `control_weight` of 0.5 (and negative pixels 0.0).
-* The part of the prompt describing the woman's clothing is changed to a cream-colored and brown shirt. Since this area of the video will be conditioned only by `depth` and `seg`, there is no conflict with the color information from the `vis` modality.
-
-In effect, for the configuration given in `assets/inference_cosmos_transfer1_spatiotemporal_weights_auto.json`, the `seg` and `depth` modalities will be applied everywhere uniformly, while `vis` and `edge` will be applied exclusively in the spatiotemporal mask given by the union of the `robotic arms` and `gloves` mask detections. In those areas, the modality weights will be normalized to sum to one, so `vis`, `edge`, `seg`, and `depth` will be applied evenly there.
-
-
-### Example 4: batch generation
-This example runs inference on a batch of prompts, provided through the `--batch_input_path` argument (path to a JSONL file). This enables running multiple generations with different prompts (and per-video control input customization) based on the same controlnet configuration.
-Each line in the JSONL file must contain a `visual_input` field, equivalent to the `--input_video_path` argument in the case of single control generation. It can also contain a `prompt` field. The batch system supports automatic control input generation, manual override of specific controls per video, and mixed usage of automatic and manual controls in the same batch. By default, the `input_control` specified within the controlnet spec JSON is used for all samples in the batch, and is overridden if explicitly specified in the batch input file (either with another `input_control` path or with `null` to indicate automatic generation based on the visual input).
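-
-A batch file in this format can be generated with a few lines of Python. The sketch below is illustrative (the entry paths and the output name `batch_input.jsonl` are hypothetical); the accepted fields follow the format example right after it:
-
-```python
-import json
-
-# Hypothetical batch entries; `None` serializes to JSON null (auto-generate that control input).
-entries = [
-    {"visual_input": "path/to/video0.mp4", "prompt": "A detailed description..."},
-    {"visual_input": "path/to/video1.mp4", "prompt": "A detailed description...",
-     "control_overrides": {"seg": {"input_control": "path/to/video1_seg.mp4"},
-                           "depth": {"input_control": None}}},
-]
-
-with open("batch_input.jsonl", "w") as f:
-    for entry in entries:
-        f.write(json.dumps(entry) + "\n")
-```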
-Here is an example of the Batch Input JSONL Format -```json -{"visual_input": "path/to/video0.mp4", "prompt": "A detailed description..."} -{"visual_input": "path/to/video1.mp4", "prompt": "A detailed description...", "control_overrides": {"seg": {"input_control": "path/to/video1_seg.mp4"}, "depth": {"input_control": null}}} -{"visual_input": "path/to/video2.mp4", "prompt": "A detailed description...", "control_overrides": {"seg": {"input_control": "path/to/video2_seg.mp4"}, "depth": {"input_control": "path/to/video2_depth.mp4"}}} -``` -Inference command (with 9 input frames): -```bash -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPU="${NUM_GPU:=1}" -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_folder outputs/example2_uniform_weights \ - --controlnet_specs assets/inference_cosmos_transfer1_uniform_weights.json \ - --offload_text_encoder_model \ - --batch_input_path path/to/batch_input_path.json \ - --num_gpus $NUM_GPU -``` - - -## Arguments - -| Parameter | Description | Default | -|-----------|-------------|---------| -| `--controlnet_specs` | A JSON describing the Multi-ControlNet config | JSON | -| `--checkpoint_dir` | Directory containing model weights | "checkpoints" | -| `--tokenizer_dir` | Directory containing tokenizer weights | "Cosmos-Tokenize1-CV8x8x8-720p" | -| `--input_video_path` | The path to the input video | None | -| `--video_save_name` | Output video filename for single video generation | "output" | -| `--video_save_folder` | Output directory for batch video generation | "outputs/" | -| `--prompt` | Text prompt for video generation. | "The video captures a stunning, photorealistic scene with remarkable attention to detail, giving it a lifelike appearance that is almost indistinguishable from reality. It appears to be from a high-budget 4K movie, showcasing ultra-high-definition quality with impeccable resolution." | -| `--negative_prompt` | Negative prompt for improved quality | "The video captures a game playing, with bad crappy graphics and cartoonish frames. It represents a recording of old outdated games. The lighting looks very fake. The textures are very raw and basic. The geometries are very primitive. The images are very pixelated and of poor CG quality. There are many subtitles in the footage. Overall, the video is unrealistic at all." | -| `--num_steps` | Number of diffusion sampling steps | 35 | -| `--guidance` | CFG guidance scale | 7.0 | -| `--sigma_max` | The level of partial noise added to the input video in the range [0, 80.0]. Any value equal or higher than 80.0 will result in not using the input video and providing the model with pure noise. | 70.0 | -| `--blur_strength` | The strength of blurring when preparing the control input for the vis controlnet. Valid values are 'very_low', 'low', 'medium', 'high', and 'very_high'. | 'medium' | -| `--canny_threshold` | The threshold for canny edge detection when preparing the control input for the edge controlnet. Lower threshold means more edges detected. Valid values are 'very_low', 'low', 'medium', 'high', and 'very_high'. 
| 'medium' | -| `--fps` | Output frames-per-second | 24 | -| `--seed` | Random seed | 1 | -| `--offload_text_encoder_model` | Offload text encoder after inference, used for low-memory GPUs | False | -| `--offload_guardrail_models` | Offload guardrail models after inference, used for low-memory GPUs | False | -| `--upsample_prompt` | Upsample prompt using prompt upsampler model | False | -| `--offload_prompt_upsampler` | Offload prompt upsampler models after inference, used for low-memory GPUs | False | - -Note: in order to run Cosmos on low-memory GPUs, you can use model offloading. This is accomplished by offloading the model from GPU memory after it has served its purpose to open space for the next model execution. - -Note: we support various aspect ratios, including 1:1 (960x960 for height and width), 4:3 (960x704), 3:4 (704x960), 16:9 (1280x704), and 9:16 (704x1280). If the input is not one of these five resolutions, it is first resized to one of them according to the nearest aspect ratio. - - - -### Examples 5: Robotics Augmentation Workflow: Mapping one robotics synthetic example to multiple realistic examples - -[Robotics Augmentation Workflow](/cosmos_transfer1/auxiliary/robot_augmentation/README.md) - - - - -## Safety Features - -The model uses a built-in safety guardrail system that cannot be disabled. Generating human faces is not allowed and will be blurred by the guardrail. - -## Prompting Instructions - -The input prompt is the most important parameter under the user's control when interacting with the model. Providing rich and descriptive prompts can positively impact the output quality of the model, whereas short and poorly detailed prompts can lead to subpar video generation. Here are some recommendations to keep in mind when crafting text prompts for the model: - -1. **Describe a single, captivating scene**: Focus on a single scene to prevent the model from generating videos with unnecessary shot changes. -2. **Limit camera control instructions**: The model doesn't handle prompts involving camera control well, as this feature is still under development. diff --git a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_4kupscaler.md b/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_4kupscaler.md deleted file mode 100644 index d7967d7d2c2cc03784d6ff1fc92127ac08c67df1..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_4kupscaler.md +++ /dev/null @@ -1,133 +0,0 @@ -# Transfer Inference featuring 4k Upscaler - -## Install Cosmos-Transfer1 - -### Environment setup - -Please refer to the Inference section of [INSTALL.md](/INSTALL.md#inference) for instructions on environment setup. - -### Download Checkpoints - -1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token. Set the access token to 'Read' permission (default is 'Fine-grained'). - -2. Log in to Hugging Face with the access token: - -```bash -huggingface-cli login -``` - -3. Accept the [Llama-Guard-3-8B terms](https://huggingface.co/meta-llama/Llama-Guard-3-8B) - -4. Download the Cosmos model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-transfer1-67c9d328196453be6e568d3e): - -```bash -PYTHONPATH=$(pwd) python scripts/download_checkpoints.py --output_dir checkpoints/ -``` - -Note that this will require about 300GB of free storage. Not all these checkpoints will be used in every generation. - -5. 
The downloaded files should be in the following structure: - -``` -checkpoints/ -├── nvidia -│ │ -│ ├── Cosmos-Guardrail1 -│ │ ├── README.md -│ │ ├── blocklist/... -│ │ ├── face_blur_filter/... -│ │ └── video_content_safety_filter/... -│ │ -│ ├── Cosmos-Transfer1-7B -│ │ ├── base_model.pt -│ │ ├── vis_control.pt -│ │ ├── edge_control.pt -│ │ ├── seg_control.pt -│ │ ├── depth_control.pt -│ │ ├── 4kupscaler_control.pt -│ │ └── config.json -│ │ -│ ├── Cosmos-Transfer1-7B-Sample-AV/ -│ │ ├── base_model.pt -│ │ ├── hdmap_control.pt -│ │ └── lidar_control.pt -│ │ -│ └── Cosmos-Tokenize1-CV8x8x8-720p -│ ├── decoder.jit -│ ├── encoder.jit -│ ├── autoencoder.jit -│ └── mean_std.pt -│ -├── depth-anything/... -├── facebook/... -├── google-t5/... -├── IDEA-Research/... -└── meta-llama/... -``` - -## Run Example - -For a general overview of how to use the model see [this guide](/examples/inference_cosmos_transfer1_7b.md). - - -Ensure you are at the root of the repository before executing the following: - -```bash -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPU="${NUM_GPU:=1}" -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_folder outputs/inference_upscaler \ - --controlnet_specs assets/inference_upscaler.json \ - --num_steps 10 \ - --offload_text_encoder_model \ - --num_gpus $NUM_GPU -``` - -You can also choose to run the inference on multiple GPUs as follows: - -```bash -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0,1,2,3}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPU="${NUM_GPU:=4}" -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_folder outputs/inference_upscaler \ - --controlnet_specs assets/inference_upscaler.json \ - --num_steps 10 \ - --offload_text_encoder_model \ - --num_gpus $NUM_GPU -``` - -This launches `transfer.py` and configures the controlnets for inference according to `assets/inference_upscaler.json`: - -```json -{ - "input_video_path" : "assets/inference_upscaler_input_video.mp4", - "upscale": { - "control_weight": 0.5 - }, -} -``` - -### Explanation of the controlnet spec - -* `prompt` (optional) specifies the prompt for the upscaler. If no prompt is provided, a default prompt saying the video is high-quality is used. -* `input_video_path` specifies the input video -* `sigma_max` specifies the level of noise that should be added to the input video before feeding through the base model branch -* The `control_weight` parameter is a number within the range [0, 1] that controls how strongly the controlnet branch should affect the output of the model. The larger the value (closer to 1.0), the more strongly the generated video will adhere to the controlnet input. However, this rididity may come at a cost of quality. Lower (closer to 0) values would give more creative liberty to the model at the cost of reduced adherance. Usually a middleground value, say 0.5, yields optinal results. 
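-
-Since every controlnet spec uses the same `control_weight` convention, a quick standalone sanity check can catch out-of-range values before launching a run. This is not part of the repository, just a small illustration based on the ranges documented in these examples ([0, 1] for `control_weight`, [0, 80.0] for `sigma_max`):
-
-```python
-import json
-
-def check_spec(path: str) -> None:
-    """Illustrative sanity check; the inference script performs its own validation."""
-    with open(path) as f:
-        spec = json.load(f)
-    for name, cfg in spec.items():
-        if isinstance(cfg, dict) and "control_weight" in cfg:
-            w = cfg["control_weight"]
-            assert 0.0 <= w <= 1.0, f"{name}: control_weight {w} outside [0, 1]"
-    if "sigma_max" in spec:
-        assert 0.0 <= spec["sigma_max"] <= 80.0, "sigma_max outside [0, 80.0]"
-
-check_spec("assets/inference_upscaler.json")
-```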
- -### The input and output videos - -The input video is a 1280 x 704 video generated by Cosmos-Transfer1-7B-Text2World: - - - -Here is what the model outputs, a high-resolution 3840 x 2112 video: - - diff --git a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_depth.md b/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_depth.md deleted file mode 100644 index d1fe54d58a44af017e27b87075c2beed5a72c754..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_depth.md +++ /dev/null @@ -1,24 +0,0 @@ -# Transfer Inference Example: Single Control (Depth) - -Here is another simple example of using the Depth control. Many steps are similar to the [Edge example](/examples/inference_cosmos_transfer1_7b.md#example-1-single-control-edge). The main difference is to use `assets/inference_cosmos_transfer1_single_control_depth.json` as the `--controlnet_specs`: - -```bash -export CUDA_VISIBLE_DEVICES=0 -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -PYTHONPATH=$(pwd) python cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_folder outputs/example1_single_control_depth \ - --controlnet_specs assets/inference_cosmos_transfer1_single_control_depth.json \ - --offload_text_encoder_model -``` -Same as the [Edge example](/examples/inference_cosmos_transfer1_7b.md#example-1-single-control-edge), the input video is a low-resolution 640 × 480 video. - - - -This will generate a 960 x 704 video that preserves the 3D spatial structure and scene depth from the input video while enhancing visual quality, detail, and realism. - - diff --git a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_regionwise_prompt.md b/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_regionwise_prompt.md deleted file mode 100644 index 76dfdd3d8f267a55a22d0ccd362de74dcf067ce1..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_regionwise_prompt.md +++ /dev/null @@ -1,40 +0,0 @@ -# Transfer Inference Example: Multi Control with regionwise prompts - -Here is another simple example of using multi control with regionwise prompts. Many steps are similar to the [Edge example](/examples/inference_cosmos_transfer1_7b.md#example-1-single-control-edge). The main difference is to use `assets/regionalprompt_test/inference_cosmos_transfer1_single_control_regional_prompt_video_mask.json` as the `--controlnet_specs`: - -```bash -export CUDA_VISIBLE_DEVICES=0 -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -PYTHONPATH=$(pwd) python cosmos_transfer1/diffusion/inference/transfer.py \ - --video_save_folder outputs/example1_multi_control_regional_prompt_video_mask \ - --controlnet_specs assets/regionalprompt_test/inference_cosmos_transfer1_multi_control_regional_prompt_video_mask.json \ - --offload_text_encoder_model --sigma_max 80 --offload_guardrail_models -``` - -The input video is a 1280 × 704 video with other mask videos. - -Input Video - - -Environment Mask - - -Forklift Mask - - -Worker Mask - - -This will generate a 1280 x 704 video that preserves the 3D spatial structure and scene depth from the input video while enhancing visual quality, detail, and realism. 
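-
-As a final check for regionwise runs, you can confirm that the RGB input and the mask videos agree in resolution and frame count. A minimal OpenCV sketch (the file names below are hypothetical placeholders for your own input and mask videos):
-
-```python
-import cv2
-
-def video_size(path: str) -> tuple:
-    """Return (width, height, frame_count) of a video using OpenCV."""
-    cap = cv2.VideoCapture(path)
-    if not cap.isOpened():
-        raise ValueError(f"Failed to open video file: {path}")
-    size = (
-        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
-        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
-        int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
-    )
-    cap.release()
-    return size
-
-paths = ["input.mp4", "environment_mask.mp4", "forklift_mask.mp4", "worker_mask.mp4"]
-print({p: video_size(p) for p in paths})
-```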
- - diff --git a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_sample_av.md b/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_sample_av.md deleted file mode 100644 index 6597edf28421342f10d0c3417096fc7b84398edd..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_sample_av.md +++ /dev/null @@ -1,172 +0,0 @@ -# Sample-AV Transfer - -## Install Cosmos-Transfer1 - -### Environment setup - -Please refer to the Inference section of [INSTALL.md](/INSTALL.md#inference) for instructions on environment setup. - -### Download Checkpoints - -1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token. Set the access token to 'Read' permission (default is 'Fine-grained'). - -2. Log in to Hugging Face with the access token: - -```bash -huggingface-cli login -``` - -3. Accept the [Llama-Guard-3-8B terms](https://huggingface.co/meta-llama/Llama-Guard-3-8B) - -4. Download the Cosmos model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-transfer1-67c9d328196453be6e568d3e): - -```bash -PYTHONPATH=$(pwd) python scripts/download_checkpoints.py --output_dir checkpoints/ -``` - -Note that this will require about 300GB of free storage. Not all these checkpoints will be used in every generation. - -5. The downloaded files should be in the following structure: - -``` -checkpoints/ -├── nvidia -│ │ -│ ├── Cosmos-Guardrail1 -│ │ ├── README.md -│ │ ├── blocklist/... -│ │ ├── face_blur_filter/... -│ │ └── video_content_safety_filter/... -│ │ -│ ├── Cosmos-Transfer1-7B -│ │ ├── base_model.pt -│ │ ├── vis_control.pt -│ │ ├── edge_control.pt -│ │ ├── seg_control.pt -│ │ ├── depth_control.pt -│ │ ├── 4kupscaler_control.pt -│ │ └── config.json -│ │ -│ ├── Cosmos-Transfer1-7B-Sample-AV/ -│ │ ├── base_model.pt -│ │ ├── hdmap_control.pt -│ │ └── lidar_control.pt -│ │ -│ └── Cosmos-Tokenize1-CV8x8x8-720p -│ ├── decoder.jit -│ ├── encoder.jit -│ ├── autoencoder.jit -│ └── mean_std.pt -│ -├── depth-anything/... -├── facebook/... -├── google-t5/... -├── IDEA-Research/... -└── meta-llama/... -``` - -## Run Example - -For a general overview of how to use the model see [this guide](/examples/inference_cosmos_transfer1_7b.md). - -This is an example of post-training Cosmos-Transfer1 using autonomous vehicle (AV) data. Here we provide two controlnets, `hdmap` and `lidar`, that allow transfering from those domains to the real world. - -Ensure you are at the root of the repository before executing the following: - -```bash -#!/bin/bash -export PROMPT="The video is captured from a camera mounted on a car. The camera is facing forward. The video showcases a scenic golden-hour drive through a suburban area, bathed in the warm, golden hues of the setting sun. The dashboard camera captures the play of light and shadow as the sun’s rays filter through the trees, casting elongated patterns onto the road. The streetlights remain off, as the golden glow of the late afternoon sun provides ample illumination. The two-lane road appears to shimmer under the soft light, while the concrete barrier on the left side of the road reflects subtle warm tones. The stone wall on the right, adorned with lush greenery, stands out vibrantly under the golden light, with the palm trees swaying gently in the evening breeze. Several parked vehicles, including white sedans and vans, are seen on the left side of the road, their surfaces reflecting the amber hues of the sunset. 
The trees, now highlighted in a golden halo, cast intricate shadows onto the pavement. Further ahead, houses with red-tiled roofs glow warmly in the fading light, standing out against the sky, which transitions from deep orange to soft pastel blue. As the vehicle continues, a white sedan is seen driving in the same lane, while a black sedan and a white van move further ahead. The road markings are crisp, and the entire setting radiates a peaceful, almost cinematic beauty. The golden light, combined with the quiet suburban landscape, creates an atmosphere of tranquility and warmth, making for a mesmerizing and soothing drive." -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPU="${NUM_GPU:=1}" -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_name output_video \ - --video_save_folder outputs/sample_av_multi_control \ - --prompt "$PROMPT" \ - --sigma_max 80 \ - --offload_text_encoder_model --is_av_sample \ - --controlnet_specs assets/sample_av_multi_control_spec.json \ - --num_gpus $NUM_GPU -``` - -You can also choose to run the inference on multiple GPUs as follows: - -```bash -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0,1,2,3}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPU="${NUM_GPU:=4}" -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_name output_video \ - --video_save_folder outputs/sample_av_multi_control \ - --prompt "$PROMPT" \ - --sigma_max 80 \ - --offload_text_encoder_model --is_av_sample \ - --controlnet_specs assets/sample_av_multi_control_spec.json \ - --num_gpus $NUM_GPU -``` - -This launches `transfer.py` and configures the controlnets for inference according to `assets/sample_av_multi_control_spec.json`: - -```json -{ - "hdmap": { - "control_weight": 0.3, - "input_control": "assets/sample_av_multi_control_input_hdmap.mp4" - }, - "lidar": { - "control_weight": 0.7, - "input_control": "assets/sample_av_multi_control_input_lidar.mp4" - } -} -``` - -Note that unlike other examples, here we chose to provide the input prompt and some other parameters through the command line arguments, as opposed to through the spec file. This flexibility allows abstracting out the fixed parameters in the spec file and vary the dynamic parameters through the command line. - -### Additional Toolkits -We provide the `cosmos-av-sample-toolkits` at https://github.com/nv-tlabs/cosmos-av-sample-toolkits. - -This toolkit includes: - -- 10 additional raw data samples (e.g., HDMap and LiDAR), along with scripts to preprocess and render them into model-compatible inputs. -- Rendering scripts for converting other datasets, such as the Waymo Open Dataset, into inputs compatible with Cosmos-Transfer1. - -### The input and output videos - -HDMap input control: - - - - -LiDAR input control: - - - - -Output video using HDMap and LiDAR: - - - -Feel free to experiment with more specs. For example, the command below only uses HDMap: - -```bash -export PROMPT="The video is captured from a camera mounted on a car. The camera is facing forward. The video showcases a scenic golden-hour drive through a suburban area, bathed in the warm, golden hues of the setting sun. 
The dashboard camera captures the play of light and shadow as the sun’s rays filter through the trees, casting elongated patterns onto the road. The streetlights remain off, as the golden glow of the late afternoon sun provides ample illumination. The two-lane road appears to shimmer under the soft light, while the concrete barrier on the left side of the road reflects subtle warm tones. The stone wall on the right, adorned with lush greenery, stands out vibrantly under the golden light, with the palm trees swaying gently in the evening breeze. Several parked vehicles, including white sedans and vans, are seen on the left side of the road, their surfaces reflecting the amber hues of the sunset. The trees, now highlighted in a golden halo, cast intricate shadows onto the pavement. Further ahead, houses with red-tiled roofs glow warmly in the fading light, standing out against the sky, which transitions from deep orange to soft pastel blue. As the vehicle continues, a white sedan is seen driving in the same lane, while a black sedan and a white van move further ahead. The road markings are crisp, and the entire setting radiates a peaceful, almost cinematic beauty. The golden light, combined with the quiet suburban landscape, creates an atmosphere of tranquility and warmth, making for a mesmerizing and soothing drive." -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPU="${NUM_GPU:=1}" -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_name output_video \ - --video_save_folder outputs/sample_av_hdmap_spec \ - --prompt "$PROMPT" \ - --offload_text_encoder_model --is_av_sample \ - --controlnet_specs assets/sample_av_hdmap_spec.json \ - --num_gpus $NUM_GPU -``` diff --git a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_sample_av_single2multiview.md b/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_sample_av_single2multiview.md deleted file mode 100644 index 982c460ff5471e65e837505c26099c0f06739633..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_sample_av_single2multiview.md +++ /dev/null @@ -1,148 +0,0 @@ -# Transfer1 Sample-AV Single2Multiview Inference Example - -## Install Cosmos-Transfer1 - -### Environment setup - -Please refer to the Inference section of [INSTALL.md](/INSTALL.md#inference) for instructions on environment setup. - -### Download Checkpoints - -1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token. Set the access token to 'Read' permission (default is 'Fine-grained'). - -2. Log in to Hugging Face with the access token: - -```bash -huggingface-cli login -``` - -3. Accept the [LlamaGuard-7b terms](https://huggingface.co/meta-llama/LlamaGuard-7b) - -4. Download the Cosmos model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-transfer1-67c9d328196453be6e568d3e): - -```bash -PYTHONPATH=$(pwd) python scripts/download_checkpoints.py --output_dir checkpoints/ --model 7b_av -``` - -Note that this will require about 300GB of free storage. Not all these checkpoints will be used in every generation. - -5. The downloaded files should be in the following structure: - -``` -checkpoints/ -├── nvidia -│ │ -│ ├── Cosmos-Guardrail1 -│ │ ├── README.md -│ │ ├── blocklist/... -│ │ ├── face_blur_filter/... -│ │ └── video_content_safety_filter/... 
-│ │ -│ ├── Cosmos-Transfer1-7B -│ │ ├── base_model.pt -│ │ ├── vis_control.pt -│ │ ├── edge_control.pt -│ │ ├── seg_control.pt -│ │ ├── depth_control.pt -│ │ ├── 4kupscaler_control.pt -│ │ └── config.json -│ │ -│ ├── Cosmos-Transfer1-7B-Sample-AV/ -│ │ ├── base_model.pt -│ │ ├── hdmap_control.pt -│ │ └── lidar_control.pt -│ │ -│ ├── Cosmos-Transfer1-7B-Sample-AV-Single2MultiView/ -│ │ ├── t2w_base_model.pt -│ │ ├── t2w_hdmap_control.pt -│ │ ├── t2w_lidar_control.pt -│ │ ├── v2w_base_model.pt -│ │ ├── v2w_hdmap_control.pt -│ │ └── v2w_lidar_control.pt -│ │ -│ └── Cosmos-Tokenize1-CV8x8x8-720p -│ ├── decoder.jit -│ ├── encoder.jit -│ ├── autoencoder.jit -│ └── mean_std.pt -│ -├── depth-anything/... -├── facebook/... -├── google-t5/... -├── IDEA-Research/... -└── meta-llama/... -``` - -## Run Example - -For a general overview of how to use the model see [this guide](/examples/inference_cosmos_transfer1_7b.md). - -This is an example of running Cosmos-Transfer1-Sample-AV-Single2Multiview using autonomous vehicle (AV) data. Here we provide multiview `hdmap` as conditioning, transferring virtual worlds demarcated by map elements to the real world. - -Ensure you are at the root of the repository before executing the following to launch `transfer_multiview.py` and configures the controlnets for inference according to `assets/sample_av_hdmap_multiview_spec.json`: - -```bash -#!/bin/bash -export PROMPT="The video is captured from a camera mounted on a car. The camera is facing forward. The video captures a driving scene on a multi-lane highway during the day. The sky is clear and blue, indicating good weather conditions. The road is relatively busy with several cars and trucks in motion. A red sedan is driving in the left lane, followed by a black pickup truck in the right lane. The vehicles are maintaining a safe distance from each other. On the right side of the road, there are speed limit signs indicating a limit of 65 mph. The surrounding area includes a mix of greenery and industrial buildings, with hills visible in the distance. The overall environment appears to be a typical day on a highway with moderate traffic. The golden light of the late afternoon bathes the highway, casting long shadows and creating a warm, serene atmosphere. The sky is a mix of orange and blue, with the sun low on the horizon. The red sedan in the left lane reflects the golden hues, while the black pickup truck in the right lane casts a distinct shadow on the pavement. The speed limit signs stand out clearly under the fading sunlight. The surrounding greenery glows with a rich, warm tone, and the industrial buildings take on a softened appearance in the sunset." -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPUS=1 -PYTHONPATH=$(pwd) torchrun --nproc_per_node=${NUM_GPUS} cosmos_transfer1/diffusion/inference/transfer_multiview.py \ ---checkpoint_dir $CHECKPOINT_DIR \ ---video_save_name output_video_1_30_0 \ ---video_save_folder outputs/sample_av_multiview \ ---offload_text_encoder_model \ ---guidance 3 \ ---controlnet_specs assets/sample_av_hdmap_multiview_spec.json --num_gpus ${NUM_GPUS} --num_steps 30 \ ---view_condition_video assets/sample_av_mv_input_rgb.mp4 \ ---prompt "$PROMPT" -``` - -We can further extend the video we've just generated with the Cosmos-Transfer1-Sample-AV-Single2Multiview-Video2World model using this command: - -```bash -#!/bin/bash -export PROMPT="The video is captured from a camera mounted on a car. 
The camera is facing forward. The video captures a driving scene on a multi-lane highway during the day. The sky is clear and blue, indicating good weather conditions. The road is relatively busy with several cars and trucks in motion. A red sedan is driving in the left lane, followed by a black pickup truck in the right lane. The vehicles are maintaining a safe distance from each other. On the right side of the road, there are speed limit signs indicating a limit of 65 mph. The surrounding area includes a mix of greenery and industrial buildings, with hills visible in the distance. The overall environment appears to be a typical day on a highway with moderate traffic. The golden light of the late afternoon bathes the highway, casting long shadows and creating a warm, serene atmosphere. The sky is a mix of orange and blue, with the sun low on the horizon. The red sedan in the left lane reflects the golden hues, while the black pickup truck in the right lane casts a distinct shadow on the pavement. The speed limit signs stand out clearly under the fading sunlight. The surrounding greenery glows with a rich, warm tone, and the industrial buildings take on a softened appearance in the sunset."
-export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}"
-export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}"
-export NUM_GPUS=1
-PYTHONPATH=$(pwd) torchrun --nproc_per_node=${NUM_GPUS} cosmos_transfer1/diffusion/inference/transfer_multiview.py \
---checkpoint_dir $CHECKPOINT_DIR \
---video_save_name output_video_extension \
---video_save_folder outputs/sample_av_multiview \
---offload_text_encoder_model \
---guidance 3 \
---controlnet_specs assets/sample_av_hdmap_multiview_lvg_spec.json --num_gpus ${NUM_GPUS} --num_steps 30 \
---view_condition_video assets/sample_av_mv_input_rgb.mp4 \
---prompt "$PROMPT" \
---n_clip_max 3 --num_input_frames 9 --initial_condition_video outputs/sample_av_multiview/output_video.mp4
-```
-Video extension is achieved by looping the Cosmos-Transfer1-Sample-AV-Single2Multiview model to generate multiple 57-frame clips. Three additional arguments are provided to enable video extension:
-1. `--n_clip_max` controls the number of clips; it does not control the number of frames generated per clip. The model cannot generate more frames than are present in `--view_condition_video`.
-2. `--num_input_frames` controls the number of overlapping frames between consecutive clips, creating a smooth transition between them. This can be set to either `1` or `9`.
-3. `--initial_condition_video` is the video generated in the first example using the `t2w` model.
-
-We also provide `lidar`-controlled examples that can be tested by modifying the `--controlnet_specs` to `assets/sample_av_lidar_multiview_spec.json` in the above commands.
-
-## Run Post-trained Example
-If you follow the post-training example in [Training README](./training_cosmos_transfer_7B_sample_AV.md), you will eventually end up with a Waymo-style post-trained checkpoint that has 5 input and output views. The inference script is slightly different from that of the pre-trained 6-view model. We provide an example of running Cosmos-Transfer1-Sample-AV-Single2Multiview post-trained on Waymo data. Here we provide multiview `hdmap` as conditioning, transferring virtual worlds demarcated by map elements to the real world.
-
-Ensure you are at the root of the repository before executing the following, which launches `transfer_multiview.py` and configures the controlnets for inference according to `assets/sample_av_hdmap_multiview_waymo_spec.json`; the `ckpt_path` needs to match your own post-trained checkpoint:
-
-```bash
-#!/bin/bash
-export PROMPT="The video is captured from a camera mounted on a car. The camera is facing forward. The video captures a driving scene on a multi-lane highway during the day. The sky is clear and blue, indicating good weather conditions. The road is relatively busy with several cars and trucks in motion. A red sedan is driving in the left lane, followed by a black pickup truck in the right lane. The vehicles are maintaining a safe distance from each other. On the right side of the road, there are speed limit signs indicating a limit of 65 mph. The surrounding area includes a mix of greenery and industrial buildings, with hills visible in the distance. The overall environment appears to be a typical day on a highway with moderate traffic. The golden light of the late afternoon bathes the highway, casting long shadows and creating a warm, serene atmosphere. The sky is a mix of orange and blue, with the sun low on the horizon. The red sedan in the left lane reflects the golden hues, while the black pickup truck in the right lane casts a distinct shadow on the pavement. The speed limit signs stand out clearly under the fading sunlight. The surrounding greenery glows with a rich, warm tone, and the industrial buildings take on a softened appearance in the sunset."
-export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}"
-export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}"
-export NUM_GPUS=1
-PYTHONPATH=$(pwd) torchrun --nproc_per_node=${NUM_GPUS} cosmos_transfer1/diffusion/inference/transfer_multiview.py \
---checkpoint_dir $CHECKPOINT_DIR \
---video_save_name output_video \
---video_save_folder outputs/sample_av_multiview_waymo \
---offload_text_encoder_model \
---guidance 3 \
---controlnet_specs assets/sample_av_hdmap_multiview_waymo_spec.json --num_gpus ${NUM_GPUS} --num_steps 30 \
---view_condition_video assets/sample_av_mv_input_rgb.mp4 \
---prompt "$PROMPT" \
---waymo_example True
-```
diff --git a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_seg.md b/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_seg.md
deleted file mode 100644
index 20153719d0356c7d9d7f0b8dcd86e60155a3cab2..0000000000000000000000000000000000000000
--- a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_seg.md
+++ /dev/null
@@ -1,27 +0,0 @@
-# Transfer Inference Example: Single Control (Segmentation)
-
-Here is another simple example of using the Segmentation control. Many steps are similar to the [Edge example](/examples/inference_cosmos_transfer1_7b.md#example-1-single-control-edge).
The main difference is to use `assets/inference_cosmos_transfer1_single_control_seg.json` as the `--controlnet_specs`: - -```bash -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPU="${NUM_GPU:=1}" -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_folder outputs/example1_single_control_seg \ - --controlnet_specs assets/inference_cosmos_transfer1_single_control_seg.json \ - --offload_text_encoder_model \ - --offload_guardrail_models \ - --num_gpus $NUM_GPU -``` -Same as the [Edge example](/examples/inference_cosmos_transfer1_7b.md#example-1-single-control-edge), the input video is a low-resolution 640 × 480 video. - - - -This will generate a 960 x 704 video that maintains the structural consistency from the input video while enhancing visual quality, detail, and realism. - - diff --git a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_vis.md b/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_vis.md deleted file mode 100644 index 1e0b19298bb4e676c798aa09d7cf0385dac9cff7..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/examples/inference_cosmos_transfer1_7b_vis.md +++ /dev/null @@ -1,27 +0,0 @@ -# Transfer Inference Example: Single Control (Vis) - -Here is another simple example of using the Vis control. Many steps are similar to the [Edge example](/examples/inference_cosmos_transfer1_7b.md#example-1-single-control-edge). The main difference is to use `assets/inference_cosmos_transfer1_single_control_vis.json` as the `--controlnet_specs`: - -```bash -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPU="${NUM_GPU:=1}" -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_folder outputs/example1_single_control_vis \ - --controlnet_specs assets/inference_cosmos_transfer1_single_control_vis.json \ - --offload_text_encoder_model \ - --offload_guardrail_models \ - --num_gpus $NUM_GPU -``` -Same as the [Edge example](/examples/inference_cosmos_transfer1_7b.md#example-1-single-control-edge), the input video is a low-resolution 640 × 480 video. - - - -This will generate a 960 x 704 video that preserves the overall color palette, lighting, and coarse structure from the input video using the vis control. By guiding the generation with a blurred version of the input, the model maintains the original scene's visual feel while significantly enhancing visual quality, detail, and realism based on the provided prompt. - - diff --git a/cosmos-transfer1/examples/inference_cosmos_transfer1pt1_7b_keypoint.md b/cosmos-transfer1/examples/inference_cosmos_transfer1pt1_7b_keypoint.md deleted file mode 100644 index 62ff6d8f796630f55cb748f9dbdf193a38b7eede..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/examples/inference_cosmos_transfer1pt1_7b_keypoint.md +++ /dev/null @@ -1,142 +0,0 @@ -# Cosmos-Transfer1: Inference featuring keypoint control - -## Install Cosmos-Transfer1 - -### Environment setup - -Please refer to the Inference section of [INSTALL.md](/INSTALL.md#inference) for instructions on environment setup. - -### Download Checkpoints - -1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token. 
Set the access token to 'Read' permission (default is 'Fine-grained'). - -2. Log in to Hugging Face with the access token: - -```bash -huggingface-cli login -``` - -3. Accept the [Llama-Guard-3-8B terms](https://huggingface.co/meta-llama/Llama-Guard-3-8B) - -4. Download the Cosmos model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-transfer1-67c9d328196453be6e568d3e): - -```bash -PYTHONPATH=$(pwd) python scripts/download_checkpoints.py --output_dir checkpoints/ -``` - -Note that this will require about 300GB of free storage. Not all these checkpoints will be used in every generation. - -5. The downloaded files should be in the following structure: - -``` -checkpoints/ -├── nvidia -│ │ -│ ├── Cosmos-Guardrail1 -│ │ ├── README.md -│ │ ├── blocklist/... -│ │ ├── face_blur_filter/... -│ │ └── video_content_safety_filter/... -│ │ -│ ├── Cosmos-Transfer1-7B -│ │ ├── base_model.pt -│ │ ├── vis_control.pt -│ │ ├── edge_control.pt -│ │ ├── seg_control.pt -│ │ ├── depth_control.pt -│ │ ├── keypoint_control.ptg -│ │ ├── 4kupscaler_control.pt -│ │ └── config.json -│ │ -│ ├── Cosmos-Transfer1-7B-Sample-AV/ -│ │ ├── base_model.pt -│ │ ├── hdmap_control.pt -│ │ └── lidar_control.pt -│ │ -│ │── Cosmos-Tokenize1-CV8x8x8-720p -│ │ ├── decoder.jit -│ │ ├── encoder.jit -│ │ ├── autoencoder.jit -│ │ └── mean_std.pt -│ │ -│ └── Cosmos-UpsamplePrompt1-12B-Transfer -│ ├── depth -│ │ ├── consolidated.safetensors -│ │ ├── params.json -│ │ └── tekken.json -│ ├── README.md -│ ├── segmentation -│ │ ├── consolidated.safetensors -│ │ ├── params.json -│ │ └── tekken.json -│ ├── seg_upsampler_example.png -│ └── viscontrol -│ ├── consolidated.safetensors -│ ├── params.json -│ └── tekken.json -│ -├── depth-anything/... -├── facebook/... -├── google-t5/... -├── IDEA-Research/... -└── meta-llama/... -``` - -## Run Example - -For a general overview of how to use the model see [this guide](inference_cosmos_transfer1_7b.md). - -Ensure you are at the root of the repository before executing the following: - -```bash -export CUDA_VISIBLE_DEVICES=0 -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -PYTHONPATH=$(pwd) python cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_folder outputs/inference_keypoint \ - --controlnet_specs assets/inference_cosmos_transfer1_single_control_keypoint.json \ - --offload_text_encoder_model -``` - -You can also choose to run the inference on multiple GPUs as follows: - -```bash -export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0,1,2,3}" -export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}" -export NUM_GPU="${NUM_GPU:=4}" -PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 cosmos_transfer1/diffusion/inference/transfer.py \ - --checkpoint_dir $CHECKPOINT_DIR \ - --video_save_folder outputs/inference_keypoint \ - --controlnet_specs assets/inference_cosmos_transfer1_single_control_keypoint.json \ - --offload_text_encoder_model \ - --num_gpus $NUM_GPU -``` - -This launches `transfer.py` and configures the controlnets for inference according to `assets/inference_keypoint_input_video.json`: - -```json -{ - "prompt": "The video takes place in a kitchen setting ...", - "input_video_path": "assets/inference_keypoint_input_video.mp4", - "keypoint": { - "control_weight": 1.0 - } -} -``` - -### The input and output videos - -The input video looks like this: - - - - -Here's what the model outputs: - - - -Note that the faces in the generated video have been blurred by the guardrail. 
diff --git a/cosmos-transfer1/examples/process_control_input_data_for_training.md b/cosmos-transfer1/examples/process_control_input_data_for_training.md deleted file mode 100644 index 67df46b066d3bce6ccf5d978b55068af13c809cf..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/examples/process_control_input_data_for_training.md +++ /dev/null @@ -1,108 +0,0 @@ -# Processing Control Input Data for Training - -This document provides detailed information about preparing control input data for training different Cosmos-Transfer1 models. - -## DepthControl Training Data Format - -- Requires depth videos in MP4 format -- Must be frame-wise aligned with corresponding RGB videos, and has same [H, W] dimensions as the input videos. -- Place in `depth/` directory - -## SegControl Training Data Format - -The segmentation data is stored in pickle files, one per video. After loading a pickle file, the data structure is as follows: - -```python -[ - { # First detected object - 'phrase': str, # Name/description of the detected object - 'segmentation_mask_rle': { - 'data': bytes, # Run-length encoded binary mask data - 'mask_shape': tuple # Shape of the mask (height, width) - } - }, - { # Second detected object - 'phrase': str, - 'segmentation_mask_rle': { - 'data': bytes, - 'mask_shape': tuple - } - }, - # ... more detected objects -] -``` - -#### Key Components: - -1. **Object Detection**: - - List of dictionaries, one per detected object - - Each detection contains: - - `phrase`: String describing the object - - `segmentation_mask_rle`: Dictionary containing: - - `data`: RLE-encoded binary mask data - - `mask_shape`: Tuple specifying mask dimensions (height, width) - -2. **Mask Creation**: - - Reference implementation in `cosmos_transfer1/auxiliary/sam2/sam2_model.py` - - -## KeypointControl Training Data Format - -For training KeypointControl models, you need to provide a pickle file containing 2D human keypoint annotations for each frame. The pickle file should follow this structure: - -```python -{ - frame_id: [ # List of detected humans in this frame - { # Annotation for one human - 'human-bbox': np.array([x1, y1, x2, y2, confidence], dtype=np.float16), # Normalized coordinates - 'human-bbox-abs': np.array([x1, y1, x2, y2, confidence], dtype=np.float16), # Absolute coordinates - 'body-keypoints': np.array([[x, y, confidence], ...], dtype=np.float16), # Shape: [133, 3], in the COCO-Wholebody format, normalized coordinates - 'body-keypoints-abs': np.array([[x, y, confidence], ...], dtype=np.float16), # Shape: [133, 3], in the COCO-Wholebody format, absolute coordinates - 'hand-keypoints': np.array([[x, y, confidence], ...], dtype=np.float16), # Shape: [42, 3], relative coordinates. It's a duplicate of the [91:133]-th keypoints of the 'body-keypoints' - 'face-bbox': np.array([x1, y1, width, height], dtype=np.float16), # Normalized coordinates of the face bounding boxes of the humans detected - 'face-bbox-abs': np.array([x1, y1, width, height], dtype=np.int16) # Absolute coordinates of the face bounding boxes of the humans detected - }, - # ... more humans in this frame - ], - # ... more frames -} -``` - -### Key Components: - -1. **Frame ID**: - - Key in the dictionary - - Should match the corresponding video frame - -2. **Per-Human Detection**: - - List of dictionaries, one per detected human - - Each detection contains: - - Bounding boxes (normalized and absolute) - - Body keypoints (133 points) - - Hand keypoints (42 points) - - Face bounding box - -3. 
**Coordinate Systems**: - - Normalized coordinates: Values between 0 and 1 - - Absolute coordinates: Pixel coordinates in the image - - All coordinates follow [x, y] format - -4. **Confidence Scores**: - - Included for each keypoint and bounding box - - Values between 0 and 1 - - Higher values indicate more reliable detections - -### Data Preparation Tips: - -1. **Keypoint Detection**: - - We used `rtmlib` for human keypoint detection and output the COCO-Wholebody keypoint convention. - -2. **File Organization**: - - Name the pickle file to match the video name - - Place in the `keypoint/` directory - - Ensure frame IDs match video frames - -## VisControl and EdgeControl -- These are self-supervised -- No separate data preparation needed -- Control inputs are generated on-the-fly during training. diff --git a/cosmos-transfer1/examples/training_cosmos_transfer_7B_sample_AV.md b/cosmos-transfer1/examples/training_cosmos_transfer_7B_sample_AV.md deleted file mode 100644 index 2e3ef0d4cffdb72697371b1497cb5381cbee0c90..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/examples/training_cosmos_transfer_7B_sample_AV.md +++ /dev/null @@ -1,224 +0,0 @@ -# Training Cosmos-Transfer1-Sample-AV Models -In this document, we provide examples and steps to: -- Build your own Cosmos-Transfer1-Sample-AV models, training from scratch; or -- Post-train Cosmos-Transfer1-Sample-AV models from our checkpoint using your data. - -The model is trained separately for each control input type. - - -## Model Support Matrix -We support the following Cosmos-Transfer1-Sample-AV models for pre-training and post-training. Review the available models and their compute requirements for training to determine the best model for your use case. We use Tensor Parallel of size 8 for training. - -| Model Name | Model Status | Compute Requirements for Post-Training | -|-------------------------------------------------------------------|--------------|----------------------------------------| -| Cosmos-Transfer1-7B-Sample-AV [Lidar] | **Supported**| 8 NVIDIA GPUs* | -| Cosmos-Transfer1-7B-Sample-AV [HDMap] | **Supported**| 8 NVIDIA GPUs* | -| Cosmos-Transfer1-7B-Sample-AV-Single2MultiView/t2w_model [Lidar] | **Supported**| 8 NVIDIA GPUs* | -| Cosmos-Transfer1-7B-Sample-AV-Single2MultiView/t2w_model [HDMap] | **Supported**| 8 NVIDIA GPUs* | -| Cosmos-Transfer1-7B-Sample-AV-Single2MultiView/v2w_model [Lidar] | **Supported**| 8 NVIDIA GPUs* | -| Cosmos-Transfer1-7B-Sample-AV-Single2MultiView/v2w_model [HDMap] | **Supported**| 8 NVIDIA GPUs* | - -**\*** 80GB GPU memory required for training. `H100-80GB` or `A100-80GB` GPUs are recommended. - -## Environment setup - -Please refer to the training section of [INSTALL.md](/INSTALL.md#post-training) for instructions on environment setup. - -## Download Checkpoints - -1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token. Set the access token to 'Read' permission (default is 'Fine-grained'). - -2. Log in to Hugging Face with the access token: - -```bash -huggingface-cli login -``` - -3. Accept the [LlamaGuard-7b terms](https://huggingface.co/meta-llama/LlamaGuard-7b) - -4. Download the Cosmos model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-transfer1-67c9d328196453be6e568d3e). Note that this will require about 300GB of free storage. - -```bash -PYTHONPATH=$(pwd) python scripts/download_checkpoints.py --output_dir checkpoints/ -``` - -5. The downloaded files should be in the following structure. 
- -``` -checkpoints/ -├── nvidia -│ ├── Cosmos-Transfer1-7B -│ │ ├── base_model.pt -│ │ ├── vis_control.pt -│ │ ├── edge_control.pt -│ │ ├── seg_control.pt -│ │ ├── depth_control.pt -│ │ ├── keypoint_control.pt -│ │ ├── 4kupscaler_control.pt -│ │ ├── config.json -│ │ └── guardrail -│ │ ├── aegis/ -│ │ ├── blocklist/ -│ │ ├── face_blur_filter/ -│ │ └── video_content_safety_filter/ -│ │ -│ ├── Cosmos-Transfer1-7B-Sample-AV/ -│ │ ├── base_model.pt -│ │ ├── hdmap_control.pt -│ │ └── lidar_control.pt -│ │ -│ ├── Cosmos-Transfer1-7B-Sample-AV-Single2MultiView/ -│ │ ├── v2w_base_model.pt -│ │ ├── v2w_hdmap_control.pt -│ │ ├── v2w_lidar_control.pt -│ │ ├── t2w_base_model.pt -│ │ ├── t2w_hdmap_control.pt -│ │ └── t2w_lidar_control.pt -│ │ -│ │── Cosmos-Tokenize1-CV8x8x8-720p -│ │ ├── decoder.jit -│ │ ├── encoder.jit -│ │ ├── autoencoder.jit -│ │ └── mean_std.pt -│ │ -│ └── Cosmos-UpsamplePrompt1-12B-Transfer -│ ├── depth -│ │ ├── consolidated.safetensors -│ │ ├── params.json -│ │ └── tekken.json -│ ├── README.md -│ ├── segmentation -│ │ ├── consolidated.safetensors -│ │ ├── params.json -│ │ └── tekken.json -│ ├── seg_upsampler_example.png -│ └── viscontrol -│ ├── consolidated.safetensors -│ ├── params.json -│ └── tekken.json -│ -├── depth-anything/... -├── facebook/... -├── google-t5/... -└── IDEA-Research/ -``` - -Checkpoint Requirements: -- Base model (`base_model.pt` for single view, `t2w_base_model.pt`, `v2w_base_model.pt` for multiview) and tokenizer models (under `Cosmos-Tokenize1-CV8x8x8-720p`): Required for all training. -- Control modality-specific model checkpoint (e.g., `hdmap_control.pt`): Only needed for post-training that specific control. Not needed if training from scratch. -- Other folders such as `depth-anything`, `facebook/sam2-hiera-large` etc.: optional. These are helper modules to process the video data into the respective control modalities such as depth and segmentation. - -## Examples -There are 3 steps to train a Cosmos-Transfer1-Sample-AV model: preparing a dataset, prepare checkpoints, and launch training. - -### 1. Dataset Download and Preprocessing -In the example below, we use a subset of [Waymo Open Dataset](https://waymo.com/open/) dataset to demonstrate the steps for preparing the data and launching training. -Please follow the [instructions](https://github.com/nv-tlabs/cosmos-av-sample-toolkits/blob/main/docs/processing_waymo_for_transfer1.md) in [cosmos-av-sample-toolkits](https://github.com/nv-tlabs/cosmos-av-sample-toolkits) to download and convert the Waymo Open Dataset. - - -### 2. Splitting the Checkpoints to TensorParallel Checkpoints -Due to the large model size, we leverage TensorParallel (TP) to split the model weights across multiple GPUs. We use 8 for the TP size. - -```bash -# Will split the Base model checkpoint into 8 TP checkpoints -PYTHONPATH=. python scripts/convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B-Sample-AV/t2w_base_model.pt -# Example: for LidarControl checkpoint splitting for post-train. -PYTHONPATH=. python scripts/convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B-Sample-AV/t2w_lidar_control.pt - -# Example: for Single2MultiView, the base model checkpoint is different -PYTHONPATH=. python scripts/convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B-Sample-AV-Single2MultiView/t2w_base_model.pt -# Example: for Single2MultiView HDMapControl -PYTHONPATH=. 
python scripts/convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B-Sample-AV-Single2MultiView/t2w_hdmap_control.pt -``` -This will generate the TP checkpoints under `checkpoints/checkpoints_tp/*_mp_*.pt`, which we load in the training below. - -### 3. (Optional): Dry-run a Training Job -As a sanity check, run the following command to dry-run an example training job with the above data. The command will generated a full configuration of the experiment. - -```bash -export OUTPUT_ROOT=checkpoints # default value - -# Training from scratch -torchrun --nproc_per_node=1 -m cosmos_transfer1.diffusion.training.train --dryrun --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_t2w_121frames_control_input_lidar_block3_pretrain - -# Post-train from our provided checkpoint (need to first split checkpoint into TP checkpoints as instructed above) -torchrun --nproc_per_node=1 -m cosmos_transfer1.diffusion.training.train --dryrun --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_t2w_121frames_control_input_lidar_block3_posttrain -``` - -Explanation of the command: - -- The trainer and the passed (master) config script will, in the background, load the detailed experiment configurations defined in `cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_sample_av.py`, and register the experiments configurations for all `hint_keys` (control modalities), covering both pretrain and post-train. We use [Hydra](https://hydra.cc/docs/intro/) for advanced configuration composition and overriding. - -- The `CTRL_7Bv1pt3_t2w_121frames_control_input_lidar_block3_pretrain` corresponds to an experiment name registered in `ctrl_7b_tp_sample_av.py`. By specifiying this name, all the detailed config will be generated and then written to `checkpoints/cosmos_transfer1_pretrain/CTRL_7Bv1_sampleAV/CTRL_7Bv1pt3_t2w_121frames_control_input_lidar_block3_pretrain/config.yaml`. - -- To customize your training, see `cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_sample_av.py` to understand how the detailed configs of the model, trainer, dataloader etc. are defined, and edit as needed. - -### 4. Launch Training - -#### 4.a Launch Training of Cosmos-Transfer1-7B-Sample-AV -Now we can start a real training job! Removing the `--dryrun` and set `--nproc_per_node=8` will start a real training job on 8 GPUs, using Lidar conditioning: - -```bash -torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_t2w_121frames_control_input_lidar_block3_pretrain -``` -#### 4.b Launch Training of Cosmos-Transfer1-7B-Sample-AV-Single2MultiView -In this example, we instead launch a training run of the Single2MultiView model with HDMap condition: - -```bash -torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_t2w_sv2mv_57frames_control_input_hdmap_block3_pretrain -``` - -**Config group and override.** An `experiment` determines a complete group of configuration parameters (model architecture, data, trainer behavior, checkpointing, etc.). Changing the `experiment` value in the command above will decide which ControlNet model is trained, and whether it's pretrain or post-train. 
For example, replacing the experiment name in the command with `CTRL_7Bv1pt3_t2w_121frames_control_input_lidar_block3_posttrain` will post-train the LidarControl model from the downloaded checkpoint instead. - -To customize your training, see the job (experiment) config in `cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_sample_av.py` to understand how they are defined, and edit as needed. - -It is also possible to modify config parameters from the command line. For example: - -```bash -torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_t2w_121frames_control_input_lidar_block3_pretrain trainer.max_iter=100 checkpoint.save_iter=40 -``` - -This will update the maximum training iterations to 100 (default in the registered experiments: 999999999) and checkpoint saving frequency to 40 (default: 1000). - -**Saving Checkpoints and Resuming Training.** -During training, the checkpoints will be saved in the structure below. Since we use TensorParallel across 8 GPUs, 8 checkpoints will be saved each time. - -``` -checkpoints/cosmos_transfer1_pretrain/CTRL_7Bv1_sampleAV/CTRL_7Bv1pt3_t2w_121frames_control_input_lidar_block3_pretrain/checkpoints/ -├── iter_{NUMBER}.pt # "master" checkpoint, saving metadata only -├── iter_{NUMBER}_model_mp_0.pt # real TP checkpoints -├── iter_{NUMBER}_model_mp_1.pt -├── ... -├── iter_{NUMBER}_model_mp_7.pt -``` - -Since the `experiment` is uniquely associated with its checkpoint directory, rerunning the same training command after an unexpected interruption will automatically resume from the latest saved checkpoint. - -### 5. Inference Using Trained Models - -**Converting the TP checkpoints to FSDP checkpoint:** To convert Tensor Parallel (TP) checkpoints to Fully Sharded Data Parallel (FSDP) format, use the conversion script `convert_ckpt_tp_to_fsdp.py`. This script requires the same number of GPUs as your TP size (e.g., if you trained with TP_SIZE=8, you need 8 GPUs for conversion). - -Example usage for Sample-AV models: -```bash -# For single-view models -torchrun --nproc_per_node=8 convert_ckpt_tp_to_fsdp.py \ - --experiment CTRL_7Bv1pt3_t2w_121frames_control_input_lidar_block3_posttrain \ - --checkpoint-path checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_sampleAV/CTRL_7Bv1pt3_t2w_121frames_control_input_lidar_block3_posttrain/checkpoints/iter_000000100.pt - -# For SingleToMultiView models -torchrun --nproc_per_node=8 convert_ckpt_tp_to_fsdp.py \ - --experiment CTRL_7Bv1pt3_t2w_sv2mv_57frames_control_input_hdmap_block3_posttrain \ - --checkpoint-path checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_sampleAV/CTRL_7Bv1pt3_t2w_sv2mv_57frames_control_input_hdmap_block3_posttrain/checkpoints/iter_000000100.pt -``` - -Optional arguments: -- `--output-directory`: Custom directory for saving FSDP checkpoints (default: automatically generated from checkpoint path) -- `--include-base-model`: Include base model in ControlNet checkpoint (default: False) - -The script will create two files in the output directory: -1. `*_reg_model.pt`: Regular model checkpoint -2. `*_ema_model.pt`: EMA model checkpoint - -The EMA model checkpoint (`*_ema_model.pt`) typically presents better quality results and is recommended for running inference in the next stage. For more details about the conversion process and available options, refer to the script's docstring. 
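If you want to sanity-check the conversion output before moving on, the short Python sketch below is one way to do it. It is only an illustration, assuming the default `fsdp_checkpoints/` output layout and file naming described above; the experiment name and paths are copied from the example command and may need adjusting for your run. It simply lists the converted checkpoints and prefers the EMA weights for inference.

```python
# Sketch: locate converted FSDP checkpoints and prefer the EMA weights for inference.
# Paths, experiment name, and file patterns follow the layout described above (illustrative only).
import glob
import os

experiment = "CTRL_7Bv1pt3_t2w_121frames_control_input_lidar_block3_posttrain"
fsdp_dir = os.path.join(
    "checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_sampleAV", experiment, "fsdp_checkpoints"
)

ema_ckpts = sorted(glob.glob(os.path.join(fsdp_dir, "*_ema_model*.pt")))
reg_ckpts = sorted(glob.glob(os.path.join(fsdp_dir, "*_reg_model.pt")))

if not ema_ckpts and not reg_ckpts:
    raise FileNotFoundError(f"No converted checkpoints found under {fsdp_dir}")

# The EMA checkpoint typically gives better results, so use it when available.
inference_ckpt = ema_ckpts[-1] if ema_ckpts else reg_ckpts[-1]
print(f"Using checkpoint for inference: {inference_ckpt}")
```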
- -**Run inference:** Follow the steps in the [inference README](./inference_cosmos_transfer1_7b_sample_av.md). diff --git a/cosmos-transfer1/examples/training_cosmos_transfer_7b.md b/cosmos-transfer1/examples/training_cosmos_transfer_7b.md deleted file mode 100644 index df67a56fccca1ade8a4791508b182adf95a03daa..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/examples/training_cosmos_transfer_7b.md +++ /dev/null @@ -1,274 +0,0 @@ -# Training Cosmos-Transfer1 Models -In this document, we provide examples and steps to: -- Build your own Cosmos-Transfer1 models, training from scratch; or -- Post-train Cosmos-Transfer1 models from our checkpoint using your data. - -The model is trained separately for each control input type. - - -## Model Support Matrix -We support the following Cosmos-Transfer1 models for pre-training and post-training. Review the available models and their compute requirements for training to determine the best model for your use case. We use Tensor Parallel of size 8 for training. - -| Model Name | Model Status | Compute Requirements for Post-Training | -|------------------------------------------|--------------|----------------------------------------| -| Cosmos-Transfer1-7B [Depth] | **Supported**| 8 NVIDIA GPUs* | -| Cosmos-Transfer1-7B [Edge] | **Supported**| 8 NVIDIA GPUs* | -| Cosmos-Transfer1-7B [Keypoint] | **Supported**| 8 NVIDIA GPUs* | -| Cosmos-Transfer1-7B [Segmentation] | **Supported**| 8 NVIDIA GPUs* | -| Cosmos-Transfer1-7B [Vis] | **Supported**| 8 NVIDIA GPUs* | - -**\*** 80GB GPU memory required for training. `H100-80GB` or `A100-80GB` GPUs are recommended. - -## Environment setup - -Please refer to the training section of [INSTALL.md](/INSTALL.md#post-training) for instructions on environment setup. - -## Download Checkpoints - -1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token. Set the access token to 'Read' permission (default is 'Fine-grained'). - -2. Log in to Hugging Face with the access token: - -```bash -huggingface-cli login -``` - -3. Accept the [LlamaGuard-7b terms](https://huggingface.co/meta-llama/LlamaGuard-7b) - -4. Download the Cosmos model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-transfer1-67c9d328196453be6e568d3e). Note that this will require about 300GB of free storage. - -```bash -PYTHONPATH=$(pwd) python scripts/download_checkpoints.py --output_dir checkpoints/ -``` - -5. The downloaded files should be in the following structure. - -``` -checkpoints/ -├── nvidia -│ ├── Cosmos-Transfer1-7B -│ │ ├── base_model.pt -│ │ ├── vis_control.pt -│ │ ├── edge_control.pt -│ │ ├── seg_control.pt -│ │ ├── depth_control.pt -│ │ ├── keypoint_control.pt -│ │ ├── 4kupscaler_control.pt -│ │ ├── config.json -│ │ └── guardrail -│ │ ├── aegis/ -│ │ ├── blocklist/ -│ │ ├── face_blur_filter/ -│ │ └── video_content_safety_filter/ -│ │ -│ ├── Cosmos-Transfer1-7B-Sample-AV/ -│ │ ├── base_model.pt -│ │ ├── hdmap_control.pt -│ │ └── lidar_control.pt -│ │ -│ │── Cosmos-Tokenize1-CV8x8x8-720p -│ │ ├── decoder.jit -│ │ ├── encoder.jit -│ │ ├── autoencoder.jit -│ │ └── mean_std.pt -│ │ -│ └── Cosmos-UpsamplePrompt1-12B-Transfer -│ ├── depth -│ │ ├── consolidated.safetensors -│ │ ├── params.json -│ │ └── tekken.json -│ ├── README.md -│ ├── segmentation -│ │ ├── consolidated.safetensors -│ │ ├── params.json -│ │ └── tekken.json -│ ├── seg_upsampler_example.png -│ └── viscontrol -│ ├── consolidated.safetensors -│ ├── params.json -│ └── tekken.json -│ -├── depth-anything/... 
-├── facebook/... -├── google-t5/... -└── IDEA-Research/ -``` - -Checkpoint Requirements: -- Base model (`base_model.pt`) and tokenizer models (under `Cosmos-Tokenize1-CV8x8x8-720p`): Required for all training. -- Control modality-specific model checkpoint (e.g., `seg_control.pt`): Only needed for post-training that specific control. Not needed if training from scratch. -- Other folders such as `depth-anything`, `facebook/sam2-hiera-large` etc.: optional. These are helper modules to process the video data into the respective control modalities such as depth and segmentation. - -## Example -There are 3 steps to train a Cosmos-Transfer1 model: preparing a dataset, prepare checkpoints, and launch training. - -In the example below, we use a subset of [HD-VILA-100M](https://github.com/microsoft/XPretrain/tree/main/hd-vila-100m) dataset to demonstrate the steps for preparing the data and launching training. After preprocessing, your dataset directory should be structured as follows: -``` -datasets/hdvila/ -├── metas/ -│ ├── *.json -│ ├── *.txt -├── videos/ -│ ├── *.mp4 -├── t5_xxl/ -│ ├── *.pickle -├── keypoint/ -│ ├── *.pickle -├── depth/ -│ ├── *.mp4 -├── seg/ -│ ├── *.pickle -└── / - ├── -``` - -File naming must be consistent across modalities. For example, to train a SegControl model with a video named `videos/example1.mp4`, the corresponding annotation files should be: `seg/example1.pickle`. - -Note: Only the folder corresponding to your chosen control input modality is required. For example, if you're training with depth as the control input, only the `depth/` subfolder is needed. - -### 1. Prepare Videos and Captions - -The first step is to prepare a dataset with videos and captions. You must provide a folder containing a collection of videos in **MP4 format**, preferably 720p. These videos should focus on the subject throughout the entire video so that each video chunk contains the subject. - -Here we use a subset of sample videos from HD-VILA-100M as an example: - -```bash -# Download metadata with video urls and captions -mkdir -p datasets/hdvila -cd datasets/hdvila -wget https://huggingface.co/datasets/TempoFunk/hdvila-100M/resolve/main/hdvila-100M.jsonl -``` - -Run the following command to download the sample videos used for training: - -```bash -# Requirements for Youtube video downloads & video clipping -pip install pytubefix ffmpeg -``` - -```bash -# The script will downlaod the original HD-VILA-100M videos, save the corresponding clips, the captions and the metadata. -PYTHONPATH=$(pwd) python scripts/download_diffusion_example_data.py --dataset_path datasets/hdvila --N_videos 128 --do_download --do_clip -``` - -### 2. Computing T5 Text Embeddings -Run the following command to pre-compute T5-XXL embeddings for the video captions used for training: - -```bash -# The script will read the captions, save the T5-XXL embeddings in pickle format. -PYTHONPATH=$(pwd) python scripts/get_t5_embeddings.py --dataset_path datasets/hdvila -``` - -### 3. Obtaining the Control Input Data -Next, we generate the control input data corresponding to each video. If you already have accurate control input data (e.g., ground truth depth, segmentation masks, or human keypoints), you can skip this step -- just ensure your files are organized in the above structure, and follow the data format as detailed in [Process Control Input Data](process_control_input_data_for_training.md). - -Here, as an example, we show show how to obtain the control input signals from the input RGB videos. 
Specifically: - -- DepthControl requires a depth video that is frame-wise aligned with the corresponding RGB video. This can be obtained by, for example, running [DepthAnythingV2](https://github.com/DepthAnything/Depth-Anything-V2) on the input videos. - -- SegControl requires a `.pickle` file in the SAM2 output format containing per-frame segmentation masks. See [Process Control Input Data](process_control_input_data_for_training.md) for detailed format requirements. - -- KeypointControl requires a `.pickle` file containing 2D human keypoint annotations for each frame. See [Process Control Input Data](process_control_input_data_for_training.md) for detailed format requirements. - -For VisControl and EdgeControl models: training is self-supervised. These models get control inputs (e.g., by applying blur or extracting Canny edges) from the input videos on-the-fly during training. Therefore, you do not need to prepare control input data separately for these modalities. - - -### 4. Splitting the Checkpoints to TensorParallel Checkpoints -Due to the large model size, we leverage TensorParallel (TP) to split the model weights across multiple GPUs. We use 8 for the TP size. - -```bash -# Will split the Base model checkpoint into 8 TP checkpoints -PYTHONPATH=. python scripts/convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B/base_model.pt -# Example: for EdgeControl checkpoint splitting for post-train. -PYTHONPATH=. python scripts/convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B/edge_control.pt -``` -This will generate the TP checkpoints under `checkpoints/checkpoints_tp/*_mp_*.pt`, which we load in the training below. - -### 5. (Optional): Dry-run a Training Job -As a sanity check, run the following command to dry-run an example training job with the above data. The command will generated a full configuration of the experiment. - -```bash -export OUTPUT_ROOT=checkpoints # default value - -# Training from scratch -torchrun --nproc_per_node=1 -m cosmos_transfer1.diffusion.training.train --dryrun --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain - -# Post-train from our provided checkpoint (need to first split checkpoint into TP checkpoints as instructed above) -torchrun --nproc_per_node=1 -m cosmos_transfer1.diffusion.training.train --dryrun --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_posttrain -``` - -Explanation of the command: - -- The trainer and the passed (master) config script will, in the background, load the detailed experiment configurations defined in `cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py`, and register the experiments configurations for all `hint_keys` (control modalities), covering both pretrain and post-train. We use [Hydra](https://hydra.cc/docs/intro/) for advanced configuration composition and overriding. - -- The `CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain` corresponds to an experiment name registered in `ctrl_7b_tp_121frames.py`. By specifiying this name, all the detailed config will be generated and then written to `checkpoints/cosmos_transfer1_pretrain/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain/config.yaml`. 
- -- To customize your training, see `cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py` to understand how the detailed configs of the model, trainer, dataloader etc. are defined, and edit as needed. - -### 6. Launch Training -Now we can start a real training job! Removing the `--dryrun` and set `--nproc_per_node=8` will start a real training job on 8 GPUs: - -```bash -torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain -``` - -**Config group and override.** An `experiment` determines a complete group of configuration parameters (model architecture, data, trainer behavior, checkpointing, etc.). Changing the `experiment` value in the command above will decide which ControlNet model is trained, and whether it's pretrain or post-train. For example, replacing the experiment name in the command with `CTRL_7Bv1pt3_lvg_tp_121frames_control_input_depth_block3_posttrain` will post-train the DepthControl model from the downloaded checkpoint instead. - -To customize your training, see the job (experiment) config in `cosmos_transfer1/diffusion/config/training/experiment/ctrl_7b_tp_121frames.py` to understand how they are defined, and edit as needed. - -It is also possible to modify config parameters from the command line. For example: - -```bash -torchrun --nproc_per_node=8 -m cosmos_transfer1.diffusion.training.train --config=cosmos_transfer1/diffusion/config/config_train.py -- experiment=CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain trainer.max_iter=100 checkpoint.save_iter=40 -``` - -This will update the maximum training iterations to 100 (default in the registered experiments: 999999999) and checkpoint saving frequency to 40 (default: 1000). - -**Saving Checkpoints and Resuming Training.** -During training, the checkpoints will be saved in the structure below. Since we use TensorParallel across 8 GPUs, 8 checkpoints will be saved each time. - -``` -checkpoints/cosmos_transfer1_pretrain/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_edge_block3_pretrain/checkpoints/ -├── iter_{NUMBER}.pt # "master" checkpoint, saving metadata only -├── iter_{NUMBER}_model_mp_0.pt # real TP checkpoints -├── iter_{NUMBER}_model_mp_1.pt -├── ... -├── iter_{NUMBER}_model_mp_7.pt -``` - -Since the `experiment` is uniquely associated with its checkpoint directory, rerunning the same training command after an unexpected interruption will automatically resume from the latest saved checkpoint. - -### 7. Inference Using Trained Models - -**Converting the TP checkpoints to FSDP checkpoint:** To convert Tensor Parallel (TP) checkpoints to Fully Sharded Data Parallel (FSDP) format, use the conversion script `convert_ckpt_tp_to_fsdp.py`. This script requires the same number of GPUs as your TP size (e.g., if you trained with TP_SIZE=8, you need 8 GPUs for conversion). 
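Before launching the conversion (example command below), it can help to confirm that all eight TP rank files for the target iteration are actually present, since each conversion rank loads one of them. The following is a minimal sketch under the checkpoint naming shown above; the experiment name, directory, and iteration number are illustrative and should be replaced with your own.

```python
# Sketch: verify the 8 TP rank checkpoints exist before running the TP -> FSDP conversion.
# Directory layout and file naming follow the structure described above (illustrative values).
import os

TP_SIZE = 8
experiment = "CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3_posttrain"
ckpt_dir = f"checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_lvg/{experiment}/checkpoints"
iteration = "iter_000000100"

expected = [os.path.join(ckpt_dir, f"{iteration}_model_mp_{rank}.pt") for rank in range(TP_SIZE)]
missing = [path for path in expected if not os.path.isfile(path)]

if missing:
    raise FileNotFoundError(f"Missing TP rank checkpoints: {missing}")
print(f"All {TP_SIZE} TP rank checkpoints found for {iteration}; ready to convert.")
```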
- -Example usage: -```bash -torchrun --nproc_per_node=8 convert_ckpt_tp_to_fsdp.py \ - --experiment CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3_posttrain \ - --checkpoint-path checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3_posttrain/checkpoints/iter_000000100.pt -``` - -Optional arguments: -- `--output-directory`: Custom directory for saving FSDP checkpoints (default: automatically generated from checkpoint path) -- `--include-base-model`: Include base model in ControlNet checkpoint (default: False) - -The script will create two files in the output directory: -1. `*_reg_model.pt`: Regular model checkpoint -2. `*_ema_model.pt`: EMA model checkpoint - -The EMA model checkpoint (`*_ema_model.pt`) typically presents better quality results and is recommended for running inference in the next stage. For more details about the conversion process and available options, refer to the script's docstring. - -**Run inference:** Follow the steps in the [inference README](./inference_cosmos_transfer1_7b.md). - -## FAQ -**Q1: What if I want to use my own control input type? How should I modify the code?** -**A1:** Modify the following scripts: -- Add new condition in: - - `cosmos_transfer1/diffusion/conditioner.py` - - `cosmos_transfer1/diffusion/config/transfer/conditioner.py` -- Add data augmentor function in `cosmos_transfer1/diffusion/datasets/augmentors/control_input.py` -- Add new hint key in: - - `cosmos_transfer1/diffusion/inference/inference_utils.py` - - `cosmos_transfer1/diffusion/inference/world_generation_pipeline.py` -- If needed, add preprocessor in `cosmos_transfer1/auxiliary/` and update `cosmos_transfer1/diffusion/inference/preprocessors.py`. diff --git a/cosmos-transfer1/pyproject.toml b/cosmos-transfer1/pyproject.toml deleted file mode 100644 index 3e24887dae47274b747a4d5f6fb6d72911e1a061..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/pyproject.toml +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Note that the build will take some time due to transformer_engine installation which does a lot of CUDA compilation. -# But, at the end, you should be able to do an `import cosmos_transfer1` in python without any issues. 
- -[build-system] -requires = ["setuptools", "setuptools-scm"] -build-backend = "setuptools.build_meta" - -[project] -name = "cosmos-transfer1" -version = "1.0.0" -authors = [ - {name = "NVIDIA Corporation"}, -] -description = "Cosmos World Foundation Model" -readme = "README.md" -requires-python = ">=3.10" -license = {text = "Apache-2.0"} -dynamic = ["dependencies"] - -[project.urls] -homepage = "https://github.com/nvidia-cosmos/cosmos-transfer1/" -issues = "https://github.com/nvidia-cosmos/cosmos-transfer1/issues" - -[tool.setuptools.dynamic] -dependencies = {file = ["requirements.txt"]} - -[tool.setuptools.packages] -find = {} - -[tool.isort] -profile = "black" -known_third_party = ["third_party"] -skip = ["third_party"] -line_length = 120 - -[tool.black] -line-length = 120 -target_version = ['py310'] -exclude = 'third_party' - -[tool.coverage.report] -include_namespace_packages = true -skip_empty = true -omit = [ - "tests/*", - "legacy/*", - ".venv/*", - "**/test_*.py", - "config.py", - "config-3.10.py" -] diff --git a/cosmos-transfer1/requirements.txt b/cosmos-transfer1/requirements.txt deleted file mode 100644 index 8f8a8298e158a3f51bc5a36f918202858b0a2fc8..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/requirements.txt +++ /dev/null @@ -1,56 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Please keep requirements sorted alphabetically - -apex==0.9.10dev -attr==0.3.2 -attrs==25.3.0 -better_profanity==0.7.0 -cloudpickle==3.1.1 -decord==0.6.0 -dill==0.4.0 -einops==0.8.1 -hydra-core==1.3.2 -imageio[pyav,ffmpeg]==2.37.0 -iopath==0.1.10 -loguru==0.7.3 -matplotlib==3.10.3 -mediapy==1.2.4 -megatron_core==0.10.0 -natsort==8.4.0 -nltk==3.9.1 -numpy==2.2.6 -nvidia_ml_py==12.570.86 -omegaconf==2.3.0 -opencv_python==4.10.0.84 -opencv_python_headless==4.11.0.86 -packaging==25.0 -pandas==2.2.3 -peft==0.14.0 -Pillow==11.2.1 -pillow -pycocotools==2.0.8 -pynvml==12.0.0 -python-magic -PyYAML==6.0.2 -retinaface_py==0.0.2 -rtmlib==0.0.13 -sam2==1.1.0 -termcolor==3.1.0 -torch==2.7.0 -torchvision==0.22.0 -tqdm==4.67.1 -transformers==4.49.0 diff --git a/cosmos-transfer1/requirements_docker.txt b/cosmos-transfer1/requirements_docker.txt deleted file mode 100644 index ccb5a4d000200a43cb57a56f316f52f3f483be7e..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/requirements_docker.txt +++ /dev/null @@ -1,45 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Please keep requirements sorted alphabetically - -av==14.2.0 -better-profanity==0.7.0 -boto3==1.35.99 -diffusers==0.32.2 -einx==0.1.3 -huggingface-hub==0.29.2 -hydra-core==1.3.2 -imageio[pyav,ffmpeg]==2.37.0 -iopath==0.1.10 -ipdb==0.13.13 -loguru==0.7.3 -mediapy==1.2.2 -megatron-core==0.10.0 -natsort==8.4.0 -nltk==3.9.1 -omegaconf==2.3.0 -opencv-python==4.10.0.84 -pandas==2.2.3 -peft==0.14.0 -pillow -pycocotools -python-magic -retinaface-py==0.0.2 -rtmlib==0.0.13 -scikit-image==0.25.2 -sentencepiece==0.2.0 -termcolor==2.5.0 -transformers==4.49.0 diff --git a/cosmos-transfer1/scripts/check_video_links.py b/cosmos-transfer1/scripts/check_video_links.py deleted file mode 100644 index 63fcf13a2fe91c0b2c0fafa23f6017ed8af49576..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/scripts/check_video_links.py +++ /dev/null @@ -1,70 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import re - -import requests - - -def find_md_files(root="."): - for dirpath, _, filenames in os.walk(root): - for f in filenames: - if f.endswith(".md"): - yield os.path.join(dirpath, f) - - -def extract_video_urls(md_file): - with open(md_file, "r", encoding="utf-8") as f: - content = f.read() - return re.findall(r' bool: - """Check if the given key corresponds to a column-parallel parameter.""" - return ( - key.endswith("to_q.0.weight") - or key.endswith("to_k.0.weight") - or key.endswith("to_v.0.weight") - or key.endswith("block.layer1.weight") - ) - - -def is_row(key: str) -> bool: - """Check if the given key corresponds to a row-parallel parameter.""" - return key.endswith("to_out.0.weight") or key.endswith("block.layer2.weight") - - -def native_to_tp(reg_state_dict: Dict[str, Any], tp_size: int) -> List[OrderedDict]: - """Convert a regular state dict to tensor parallel state dicts. - - Args: - reg_state_dict: The regular state dictionary. - tp_size: The number of tensor parallel partitions. - - Returns: - A list of OrderedDicts, each representing a tensor parallel partition. 
- """ - tp_state_dict = [OrderedDict() for _ in range(tp_size)] - log.info("Converting to TP checkpoint..") - for key, value in reg_state_dict.items(): - if key.endswith("_extra_state"): - continue - - if is_column(key): - for i, item in enumerate(value.chunk(tp_size, dim=0)): - tp_state_dict[i][key] = item - elif is_row(key): - for i, item in enumerate(value.chunk(tp_size, dim=1)): - tp_state_dict[i][key] = item - else: - for i in range(tp_size): - tp_state_dict[i][key] = value - - return tp_state_dict - - -def convert_fsdp_to_tp(path_in: str, path_out: str) -> None: - """Convert an FSDP checkpoint to TP format. - - Args: - path_in: Path to input checkpoint (without _reg_model.pt suffix) - path_out: Path for output checkpoint (without _model_mp_X.pt suffix) - tp_size: Number of tensor parallel partitions - verbose: Whether to show progress bar - - Raises: - FileNotFoundError: If input checkpoint doesn't exist - ValueError: If paths are invalid or tp_size <= 0 - RuntimeError: For other conversion errors - """ - try: - log.info(f"Loading checkpoint from {path_in}..") - native_ckpt = torch.load( - path_in, - map_location=torch.device("cpu"), - weights_only=False, # Load to CPU first; weights_only=False required for newer PyTorch versions - ) - state_dicts = native_to_tp(native_ckpt, TP_SIZE) - except FileNotFoundError: - raise FileNotFoundError(f"Checkpoint file {path_in} not found") - except Exception as e: - raise RuntimeError(f"Error loading checkpoint: {str(e)}") - - log.info("Saving TP checkpoints..") - # Add a dummy grad_scaler and iteration to the checkpoint. Required by the training script. - easy_io.dump({"grad_scaler": {}, "iteration": 0}, f"{path_out}.pt") - for i in tqdm(range(TP_SIZE)): - state_dict = {"model": state_dicts[i], "ema": None} - easy_io.dump(state_dict, f"{path_out}_model_mp_{i}.pt") - - -if __name__ == "__main__": - """ - Example usage: converting a viscontrol model to a TP checkpoint. - - Command: - python convert_ckpt_fsdp_to_tp.py checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control.pt - - This will save the Tensor Parallel (TP) checkpoints as 8 files in the same directory: - checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_model_mp_0.pt - ... - checkpoints/nvidia/Cosmos-Transfer1-7B/vis_control_model_mp_7.pt - """ - if len(sys.argv) != 2: - print("Usage: python convert_ckpt_fsdp_to_tp.py ") - print("Example: python convert_ckpt_fsdp_to_tp.py checkpoints/model.pt") - sys.exit(1) - - checkpoint_path = sys.argv[1] - - # Create checkpoints_tp directory in the same parent directory as the input checkpoint - input_dir = os.path.dirname(checkpoint_path) - tp_ckpt_dir = os.path.join(input_dir, "checkpoints_tp") - os.makedirs(tp_ckpt_dir, exist_ok=True) - - # Use the same basename as input but in the checkpoints_tp directory - out_tp_checkpoint_path = os.path.join(tp_ckpt_dir, os.path.basename(checkpoint_path).replace(".pt", "")) - try: - convert_fsdp_to_tp(checkpoint_path, out_tp_checkpoint_path) - print(f"Conversion completed successfully! See {tp_ckpt_dir}.") - except Exception as e: - print(f"Error during conversion: {str(e)}") - sys.exit(1) diff --git a/cosmos-transfer1/scripts/convert_ckpt_tp_to_fsdp.py b/cosmos-transfer1/scripts/convert_ckpt_tp_to_fsdp.py deleted file mode 100644 index e05da8936c38b71d4e6f2a618f8cf5ca9f2f144d..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/scripts/convert_ckpt_tp_to_fsdp.py +++ /dev/null @@ -1,346 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: - torchrun --nproc_per_node=8 convert_ckpt_tp_to_fsdp.py \ - --experiment CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3_posttrain \ - --checkpoint-path checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_lvg/CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3_posttrain/checkpoints/iter_000000100.pt - -This script is designed to convert a Tensor Parallel (TP) checkpoint -to a Fully Sharded Data Parallel (FSDP) compatible format for a video diffusion model. - -Using experiment `CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3_posttrain` as an example: -For a model trained with Tensor Parallel (TP), the checkpoints are saved in the following formats: -``` - checkpoint_path = f"checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_lvg/{experiment}/checkpoints/iter_000000100_model_mp_0.pt" - checkpoint_path = f"checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_lvg/{experiment}/checkpoints/iter_000000100_model_mp_1.pt" - ... - checkpoint_path = f"checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_lvg/{experiment}/checkpoints/iter_000000100_model_mp_7.pt" -``` - -where `*_model_mp_0.pt` and `*_model_mp_1.pt` are the model checkpoints for the eight TP ranks. - -This script will load the TP model checkpoint and convert it to a FSDP-compatible format. -The converted checkpoints will be saved -to a new directory `fsdp_checkpoints` under the same experiment directory, e.g., - `checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_lvg/{experiment}/fsdp_checkpoints/`. 
- -It has the following formats: -``` -iter_000000100_reg_model.pt -iter_000000100_ema_model.pt -``` -""" - -import argparse -import os -from collections import OrderedDict - -import torch -import torch.distributed as dist -import torch.nn as nn -import transformer_engine as te -import yaml -from megatron.core import parallel_state - -from cosmos_transfer1.diffusion.config.config_train import make_config -from cosmos_transfer1.diffusion.training.train import instantiate_model -from cosmos_transfer1.utils import log -from cosmos_transfer1.utils.config_helper import override -from cosmos_transfer1.utils.easy_io import easy_io -from cosmos_transfer1.utils.misc import set_random_seed - - -@torch.no_grad -def copy_params_from_tp(model: nn.Module, model_tp: nn.Module, tp_size: int) -> None: - orig_tp_size = parallel_state.get_tensor_model_parallel_world_size() - # create temporary parallel_state for parameters & buffer copy - parallel_state.destroy_model_parallel() - parallel_state.initialize_model_parallel(tensor_model_parallel_size=tp_size) - - match_layers = OrderedDict() - ddp_group = parallel_state.get_data_parallel_group() - tp_group = parallel_state.get_tensor_model_parallel_group() - assert tp_size == parallel_state.get_tensor_model_parallel_world_size(), "TP group init is wrong" - tp_rank = parallel_state.get_tensor_model_parallel_rank() - - def record_match_layer(name, param, param_chunk, policy): - match_layers[name] = { - "shape": list(param.shape), - "copied_name": name, - "copied_shape": list(param_chunk.shape), - "policy": policy, - "type": "param", - } - - for (name, param), (name_tp, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()): - module_name_hierarchy = name.split(".") - submodule_name = ".".join(module_name_hierarchy[:-1]) - submodule = model.get_submodule(submodule_name) - submodule_tp = model_tp.get_submodule(submodule_name) - - if isinstance(submodule, nn.Linear) and isinstance(submodule_tp, te.pytorch.Linear): - # get parallel mode and copy weights - if module_name_hierarchy[-1] == "weight": - if submodule_tp.parallel_mode == "column": - param_chunks = param.chunk(tp_size, dim=0) - record_match_layer(name, param, param_chunks[tp_rank], f"column_rank{tp_rank}") - param_tp_chunks = [torch.zeros_like(param_tp) for _ in range(tp_size)] - dist.all_gather(param_tp_chunks, param_tp, tp_group, async_op=False) - for _tp_rank in range(tp_size): - param_chunks[_tp_rank].copy_(param_tp_chunks[_tp_rank], non_blocking=True) - elif submodule_tp.parallel_mode == "row": - param_chunks = param.chunk(tp_size, dim=1) - record_match_layer(name, param, param_chunks[tp_rank], f"row_rank{tp_rank}") - param_tp_chunks = [torch.zeros_like(param_tp) for _ in range(tp_size)] - dist.all_gather(param_tp_chunks, param_tp, tp_group, async_op=False) - for _tp_rank in range(tp_size): - param_chunks[_tp_rank].copy_(param_tp_chunks[_tp_rank], non_blocking=True) - else: - record_match_layer(name, param, param_tp, "direct") - param.copy_(param_tp, non_blocking=True) - elif module_name_hierarchy[-1] == "bias": - raise NotImplementedError("Bias is not supported yet.") - else: - record_match_layer(name, param, param_tp, "direct") - param.copy_(param_tp, non_blocking=True) - - # Important to also copy buffer as logvar has randomness. 
- for (name, buffer), (name_tp, buffer_tp) in zip(model.named_buffers(), model_tp.named_buffers()): - if buffer.size() == buffer_tp.size(): - match_layers[name] = { - "shape": buffer.shape, - "copied_name": name_tp, - "copied_shape": buffer_tp.shape, - "policy": "direct", - "type": "buffer", - } - buffer.copy_(buffer_tp, non_blocking=True) - else: - if "bias" in name: - raise NotImplementedError("Bias is not supported yet.") - - if "model_ema" in name: - module_name = name.replace("-", ".") - module_name = module_name.replace("model_ema", "model") - if "column" in match_layers[module_name]["policy"] or "row" in match_layers[module_name]["policy"]: - dim = 0 if "column" in match_layers[module_name]["policy"] else 1 - buffer_chunks = buffer.chunk(tp_size, dim=dim) - buffer_tp_chunks = [torch.zeros_like(buffer_tp) for _ in range(tp_size)] - dist.all_gather(buffer_tp_chunks, buffer_tp, tp_group, async_op=False) - for _tp_rank in range(tp_size): - buffer_chunks[_tp_rank].copy_(buffer_tp_chunks[_tp_rank], non_blocking=True) - else: - log.info(f"{name} is not copied due to size mismatch.") - - dist.barrier(ddp_group) - dist.barrier(tp_group) - # convert match_layers to yaml and save it to disk - yaml_fp = f"/tmp/match_layers_rank{dist.get_rank()}_tp_rank{tp_rank}.yaml" - with open(yaml_fp, "w") as f: - yaml.dump(match_layers, f) - - # recover the original parallel_state - parallel_state.destroy_model_parallel() - parallel_state.initialize_model_parallel(tensor_model_parallel_size=orig_tp_size) - - return - - -def convert_tp_checkpoint_to_fsdp( - experiment: str, - checkpoint_path: str, - output_directory: str, - include_base_model_in_ctrlnet_ckpt: bool = False, -) -> None: - """ - Convert a Tensor Parallel (TP) checkpoint to a Fully Sharded Data Parallel (FSDP) compatible format. - - This function performs the following steps: - 1. Loads a TP model checkpoint - 2. Initializes a non-TP model - 3. Converts the checkpoint from TP format to FSDP compatible format - 4. Verifies the conversion by comparing outputs, losses, and gradients - - Args: - experiment (str): The name of the experiment for which to convert the checkpoint. - checkpoint_path (str): The path to the TP checkpoint file. - output_directory (str): The directory where the converted FSDP checkpoint will be saved. - - Raises: - ValueError: If the conversion process fails or if the verification step detects significant discrepancies. - - Note: - This function assumes that the necessary configurations and dependencies are properly set up. - It uses bfloat16 as the default dtype for better performance and memory efficiency. 
- - """ - log.info(f"Converting TP checkpoint to FSDP for experiment: {experiment}") - - # Clean up any existing parallel state - parallel_state.destroy_model_parallel() - - # Set the default dtype to bfloat16 for better performance and memory efficiency - dtype = torch.bfloat16 - torch.set_default_dtype(dtype) - - # Initialize and load the Tensor Parallel (TP) model - config_tp = make_config() - override_tp = [ - "--", - f"experiment={experiment}", - f"checkpoint.load_path={checkpoint_path}", - "checkpoint.load_training_state=False", - ] - config_tp = override( - config_tp, - override_tp, - ) - - # Initialize trainer, model, optimizer, scheduler, and grad scaler for TP - trainer_tp = config_tp.trainer.type(config_tp) - # tp_group = parallel_state.get_tensor_model_parallel_group() - tp_size = parallel_state.get_tensor_model_parallel_world_size() - # global_tp_src_rank = parallel_state.get_tensor_model_parallel_src_rank() - global_rank = dist.get_rank() - - # Set random seed by global rank to ensure diversity within TP groups - set_random_seed(global_rank) - model_tp = instantiate_model(config_tp, trainer_tp).cuda() - optimizer_tp, scheduler_tp = model_tp.init_optimizer_scheduler(config_tp.optimizer, config_tp.scheduler) - grad_scaler_tp = torch.amp.GradScaler("cuda", **config_tp.trainer.grad_scaler_args) - - # Load checkpoint and prepare model for training - log.info("Loading checkpoint...") - trainer_tp.checkpointer.load(model_tp, optimizer_tp, scheduler_tp, grad_scaler_tp) - model_tp.on_train_start() - - # Initialize and prepare the non-TP model - parallel_state.destroy_model_parallel() - - config = make_config() - config = override( - config, - [ - "--", - f"experiment={experiment}", - "ckpt_klass=multi_rank", - "checkpoint.load_path=''", - "model_parallel.tensor_model_parallel_size=1", - "model_parallel.sequence_parallel=False", - ], - ) - - # Initialize non-TP model and copy parameters from TP model - trainer = config.trainer.type(config) - model = instantiate_model(config, trainer).cuda() - model.on_train_start() - copy_params_from_tp(model, model_tp, tp_size=tp_size) - - # Save the converted model checkpoints - if torch.distributed.get_rank() == 0: - # Save regular model checkpoint - checkpoint_name = os.path.basename(checkpoint_path) - reg_model_checkpoint_name = checkpoint_name.replace(".pt", "_reg_model.pt") - reg_model_path = os.path.join(output_directory, reg_model_checkpoint_name) - - # Save EMA model checkpoint with necessary post-processing - ema_state_dict = {k.replace("-", "."): v for k, v in model.state_dict()["ema"].items()} - for key in ["net.pos_embedder.seq", "logvar.0.freqs", "logvar.0.phases"]: - ema_state_dict[key] = model.state_dict()["model"][key] - - if include_base_model_in_ctrlnet_ckpt: - # Copy base model keys to ema dict for controlnets. 
- for key in model.state_dict()["model"].keys(): - if key.startswith("base_model") and key not in ema_state_dict: - ema_state_dict[key] = model.state_dict()["model"][key] - - ema_model_checkpoint_name = checkpoint_name.replace(".pt", "_ema_model.pt") - else: - ema_model_checkpoint_name = checkpoint_name.replace(".pt", "_ema_model_only.pt") - ema_model_path = os.path.join(output_directory, ema_model_checkpoint_name) - easy_io.dump(ema_state_dict, ema_model_path) - - # clean up the base model in the state dict if include_base_model_in_ctrlnet_ckpt is False - if not include_base_model_in_ctrlnet_ckpt: - # Get the state dict first - state_dict = model.state_dict()["model"] - # Create a new dict without base_model keys - filtered_state_dict = {k: v for k, v in state_dict.items() if not k.startswith("base_model")} - - # Save the filtered state dict directly - easy_io.dump(filtered_state_dict, reg_model_path) - else: - # Save the original state dict - easy_io.dump(model.state_dict()["model"], reg_model_path) - - log.info( - f"Conversion complete. FSDP-compatible checkpoints saved for experiment: {experiment}\n" - f"Regular model saved at {reg_model_path}\n" - f"EMA model saved at {ema_model_path}" - ) - - -if __name__ == "__main__": - # Set up argument parser - parser = argparse.ArgumentParser(description="Convert Tensor Parallel checkpoint to FSDP format") - parser.add_argument( - "--experiment", - type=str, - default="CTRL_7Bv1pt3_lvg_tp_121frames_control_input_seg_block3_posttrain", - help="Name of the experiment", - ) - parser.add_argument( - "--checkpoint-path", - type=str, - help="Path to the checkpoint file. If not provided, will be constructed using experiment name", - ) - parser.add_argument( - "--output-directory", - type=str, - help="Directory to save the converted checkpoints. If not provided, will be constructed using checkpoint path", - ) - parser.add_argument( - "--include-base-model", - action="store_true", - default=False, - help="Include base model in controlnet checkpoint (default: False)", - ) - - args = parser.parse_args() - - # Set default checkpoint path if not provided - if args.checkpoint_path is None: - args.checkpoint_path = ( - f"checkpoints/cosmos_transfer1_posttrain/CTRL_7Bv1_lvg/{args.experiment}/checkpoints/iter_000000100.pt" - ) - - # Set default output directory if not provided - if args.output_directory is None: - args.output_directory = os.path.dirname(args.checkpoint_path).replace( - f"{args.experiment}/checkpoints", f"{args.experiment}/fsdp_checkpoints" - ) - - # Create output directory - os.makedirs(args.output_directory, exist_ok=True) - - # Convert checkpoint - convert_tp_checkpoint_to_fsdp( - experiment=args.experiment, - checkpoint_path=args.checkpoint_path, - output_directory=args.output_directory, - include_base_model_in_ctrlnet_ckpt=args.include_base_model, - ) diff --git a/cosmos-transfer1/scripts/download_diffusion_example_data.py b/cosmos-transfer1/scripts/download_diffusion_example_data.py deleted file mode 100644 index 70bd4ba23a538d8f12dfa1fe7402ab9e33a7cd08..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/scripts/download_diffusion_example_data.py +++ /dev/null @@ -1,121 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import json -import os - -import ffmpeg -from pytubefix import YouTube - -"""example command -CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_diffusion_example_data.py --dataset_path datasets/hdvila --N_videos 128 --do_download --do_clip -""" - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Download example (hdvila) data for posttraining") - parser.add_argument("--dataset_path", type=str, default="datasets/hdvila", help="Root path to the dataset") - parser.add_argument("--N_videos", type=int, default=128, help="Number of videos to download") - parser.add_argument("--do_download", action="store_true", help="Download the videos") - parser.add_argument("--do_clip", action="store_true", help="Clip the videos") - return parser.parse_args() - - -def convert_time_to_seconds(time_str) -> int: - h, m, s = map(float, time_str.split(":")) - ms = int(time_str.split(".")[-1]) if "." in time_str else 0 - return int(h * 3600 + m * 60 + s) + ms / 1000 - - -def download_data(args) -> None: - urls_set = set() - download_count = 0 - - videos_orig_dir = os.path.join(args.dataset_path, "videos_orig") - os.makedirs(videos_orig_dir, exist_ok=True) - videos_dir = os.path.join(args.dataset_path, "videos") - os.makedirs(videos_dir, exist_ok=True) - metas_dir = os.path.join(args.dataset_path, "metas") - os.makedirs(metas_dir, exist_ok=True) - - hdvila_jsonl_path = os.path.join(args.dataset_path, "hdvila-100M.jsonl") - with open(hdvila_jsonl_path, "r") as fp: - for line in fp: - json_object = json.loads(line) - url = json_object["url"] - if url not in urls_set: # download videos with unique urls - yt = YouTube(json_object["url"]) - try: - # Download a video - yt.streams.get_highest_resolution().download( - output_path=videos_orig_dir, filename=json_object["video_id"] + ".mp4" - ) - download_count += 1 - urls_set.add(url) - print(f"Downloaded videos: {download_count}/{args.N_videos}") - - # Save metadata - caption and whole metadata - meta_txt_name = os.path.join(metas_dir, json_object["clip_id"].replace(".mp4", ".txt")) - with open(meta_txt_name, "w") as fp: - fp.write(json_object["caption"]) - meta_json_name = os.path.join(metas_dir, json_object["clip_id"].replace(".mp4", ".json")) - with open(meta_json_name, "w") as fp: - json.dump(json_object, fp) - except Exception as e: - print(e) - continue - - if len(urls_set) >= args.N_videos: - break - - -def clip_data(args) -> None: - videos_orig_dir = os.path.join(args.dataset_path, "videos_orig") - videos_dir = os.path.join(args.dataset_path, "videos") - os.makedirs(videos_dir, exist_ok=True) - metas_dir = os.path.join(args.dataset_path, "metas") - - metas_list = [ - os.path.join(metas_dir, filename) for filename in sorted(os.listdir(metas_dir)) if filename.endswith(".json") - ] - videos_orig_list = [ - os.path.join(videos_orig_dir, filename) - for filename in sorted(os.listdir(videos_orig_dir)) - if filename.endswith(".mp4") - ] - - for meta_filename, video_orig_filename in zip(metas_list, videos_orig_list): - with open(meta_filename, "r") as fp: - metadata = json.load(fp) - - # 
Convert time strings to seconds - start_time = convert_time_to_seconds(metadata["span_start"]) - end_time = convert_time_to_seconds(metadata["span_end"]) - # Clip the video - clip_name = os.path.join(videos_dir, metadata["clip_id"]) - ffmpeg.input(video_orig_filename, ss=start_time, t=end_time - start_time).output(clip_name).run() - - -def main(args) -> None: - if args.do_download: - download_data(args) - if args.do_clip: - clip_data(args) - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/cosmos-transfer1/scripts/format.sh b/cosmos-transfer1/scripts/format.sh deleted file mode 100644 index 42c3e2bd41407ab284b14bf2cb0bc76d67785374..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/scripts/format.sh +++ /dev/null @@ -1,47 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -cosmos_root=$(git rev-parse --show-toplevel) -venv_folder=$cosmos_root/.venv -scripts_folder=$cosmos_root/scripts - -echo "Formatting $cosmos_root" -if [ ! -d "$scripts_folder" ]; then - echo "script has to be called from repo root dir!" - exit -1 -fi - -if [ ! -d "$venv_folder" ]; then - mkdir -p $venv_folder - python3 -m pip install virtualenv - python3 -m venv $venv_folder -fi - -source $venv_folder/bin/activate - -dependencies=($(pip freeze | grep -E 'pre-commit==3.7.1|flake8==7.1.0|black==24.4.2|isort==5.13.2|loguru|termcolor')) -if [ "${#dependencies[@]}" -ne 6 ]; then - python3 -m pip install --upgrade pip - python3 -m pip install pre-commit==3.7.1 - python3 -m pip install flake8==7.1.0 - python3 -m pip install black==24.4.2 - python3 -m pip install isort==5.13.2 - python3 -m pip install loguru - python3 -m pip install termcolor -fi -set -e -python3 $scripts_folder/ip_header.py -pre-commit install-hooks -pre-commit run --all diff --git a/cosmos-transfer1/scripts/get_t5_embeddings.py b/cosmos-transfer1/scripts/get_t5_embeddings.py deleted file mode 100644 index 53b6ebe889340e9360a946fa3022766dcf027c36..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/scripts/get_t5_embeddings.py +++ /dev/null @@ -1,126 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import os -import pickle -from typing import Tuple - -import numpy as np -import torch -from transformers import T5EncoderModel, T5TokenizerFast - -"""example command -CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/get_t5_embeddings.py --dataset_path datasets/hdvila -""" - - -def parse_args() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Compute T5 embeddings for text prompts") - parser.add_argument("--dataset_path", type=str, default="datasets/hdvila", help="Root path to the dataset") - parser.add_argument("--max_length", type=int, default=512, help="Maximum length of the text embedding") - parser.add_argument( - "--pretrained_model_name_or_path", type=str, default="google-t5/t5-11b", help="T5 model name or the local path" - ) - parser.add_argument("--cache_dir", type=str, default="checkpoints", help="Directory to cache the T5 model") - return parser.parse_args() - - -def init_t5( - pretrained_model_name_or_path: str = "google-t5/t5-11b", max_length: int = 512, cache_dir: str = "~/.cache" -) -> Tuple[T5TokenizerFast, T5EncoderModel]: - """Initialize and return the T5 tokenizer and text encoder.""" - tokenizer = T5TokenizerFast.from_pretrained( - pretrained_model_name_or_path, model_max_length=max_length, cache_dir=cache_dir - ) - text_encoder = T5EncoderModel.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir) - text_encoder.to("cuda") - text_encoder.eval() - return tokenizer, text_encoder - - -@torch.inference_mode() -def encode_for_batch(tokenizer, encoder, prompts: list[str], max_length=512) -> list: - """ - Encode a batch of text prompts to a batch of T5 embeddings. - Parameters: - tokenizer: T5 embedding tokenizer. - encoder: T5 embedding text encoder. - prompts: A batch of text prompts. - max_length: Sequence length of text embedding (defaults to 512). - """ - - batch_encoding = tokenizer.batch_encode_plus( - prompts, - return_tensors="pt", - truncation=True, - padding="max_length", - max_length=max_length, - return_length=True, - return_offsets_mapping=False, - ) - - # We expect all the processing is done on GPU. 
- input_ids = batch_encoding.input_ids.cuda() - attn_mask = batch_encoding.attention_mask.cuda() - - outputs = encoder(input_ids=input_ids, attention_mask=attn_mask) - encoded_text = outputs.last_hidden_state - - lengths = attn_mask.sum(dim=1).cpu() - for batch_id in range(encoded_text.shape[0]): - encoded_text[batch_id][lengths[batch_id] :] = 0 - - encoded_text = encoded_text.cpu().numpy().astype(np.float16) - encoded_text = encoded_text[:, :max_length] - - # trim zeros to save space - encoded_text = [encoded_text[batch_id][: lengths[batch_id]] for batch_id in range(encoded_text.shape[0])] - - return encoded_text - - -def main(args) -> None: - metas_dir = os.path.join(args.dataset_path, "metas") - metas_list = [ - os.path.join(metas_dir, filename) for filename in sorted(os.listdir(metas_dir)) if filename.endswith(".txt") - ] - - t5_xxl_dir = os.path.join(args.dataset_path, "t5_xxl") - os.makedirs(t5_xxl_dir, exist_ok=True) - - # Initialize T5 - tokenizer, text_encoder = init_t5(cache_dir=args.cache_dir) - - for meta_filename in metas_list: - t5_xxl_filename = os.path.join(t5_xxl_dir, os.path.basename(meta_filename).replace(".txt", ".pickle")) - if os.path.exists(t5_xxl_filename): - # Skip if the file already exists - continue - - with open(meta_filename, "r") as fp: - prompt = fp.read().strip() - - # Compute T5 embeddings - encoded_text = encode_for_batch(tokenizer, text_encoder, [prompt]) - - # Save T5 embeddings as pickle file - with open(t5_xxl_filename, "wb") as fp: - pickle.dump(encoded_text, fp) - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/cosmos-transfer1/scripts/ip_header.py b/cosmos-transfer1/scripts/ip_header.py deleted file mode 100644 index 80e6d4a025151419db99c45e05f255d5235bc0da..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/scripts/ip_header.py +++ /dev/null @@ -1,137 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import os -import sys - -import termcolor - -parser = argparse.ArgumentParser(description="Cosmos IP header checker/fixer") -parser.add_argument("--fix", action="store_true", help="apply the fixes instead of checking") -args, files_to_check = parser.parse_known_args() - - -def get_header(ext: str = "py", old: str | bool = False) -> list[str]: - header = [ - "SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved.", - "SPDX-License-Identifier: Apache-2.0", - "", - 'Licensed under the Apache License, Version 2.0 (the "License");', - "you may not use this file except in compliance with the License.", - "You may obtain a copy of the License at", - "", - "http://www.apache.org/licenses/LICENSE-2.0", - "", - "Unless required by applicable law or agreed to in writing, software", - 'distributed under the License is distributed on an "AS IS" BASIS,', - "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", - "See the License for the specific language governing permissions and", - "limitations under the License.", - ] - if ext == ".py" and old: - if old == "single": - header = ["'''"] + header + ["'''"] - elif old == "double": - header = ['"""'] + header + ['"""'] - else: - raise NotImplementedError - elif ext in (".py", ".yaml"): - header = [("# " + line if line else "#") for line in header] - elif ext in (".c", ".cpp", ".cu", ".h", ".cuh"): - header = ["/*"] + [(" * " + line if line else " *") for line in header] + [" */"] - else: - raise NotImplementedError - return header - - -def apply_file(file: str, results: dict[str, int], fix: bool = False) -> None: - if file.endswith("__init__.py"): - return - ext = os.path.splitext(file)[1] - content = open(file).read().splitlines() - header = get_header(ext=ext) - if fix: - if _check_header(content, header): - return - print(f"fixing: {file}") - while len(content) > 0 and not content[0]: - content.pop(0) - content = header + [""] + content - with open(file, "w") as file_obj: - for line in content: - file_obj.write(line + "\n") - else: - if not _check_header(content, header): - bad_header = colorize("BAD HEADER", color="red", bold=True) - print(f"{bad_header}: {file}") - results[file] = 1 - else: - results[file] = 0 - - -def traverse_directory(path: str, results: dict[str, int], fix: bool = False, substrings_to_skip=[]) -> None: - files = os.listdir(path) - for file in files: - full_path = os.path.join(path, file) - if os.path.isdir(full_path): - traverse_directory(full_path, results, fix=fix, substrings_to_skip=substrings_to_skip) - elif os.path.isfile(full_path): - ext = os.path.splitext(file)[1] - to_skip = any(substr in full_path for substr in substrings_to_skip) - if not to_skip and ext in (".py", ".yaml", ".c", ".cpp", ".cu", ".h", ".cuh"): - apply_file(full_path, results, fix=fix) - else: - raise NotImplementedError - - -def _check_header(content: list[str], header: list[str]) -> bool: - if content[: len(header)] != header: - return False - if len(content) > len(header): - if len(content) == len(header) + 1: - return False - if not (content[len(header)] == "" and content[len(header) + 1] != ""): - return False - return True - - -def colorize(x: str, color: str, bold: bool = False) -> str: - return termcolor.colored(str(x), color=color, attrs=("bold",) if bold else None) # type: ignore - - -if __name__ == "__main__": - if not files_to_check: - files_to_check = [ - "cosmos_transfer1/auxiliary", - "cosmos_transfer1/diffusion", - "cosmos_transfer1/utils", - ] - - for file in files_to_check: - assert os.path.isfile(file) or os.path.isdir(file), f"{file} is neither a directory or a file!" 
- - substrings_to_skip = ["prompt_upsampler"] - results = dict() - for file in files_to_check: - if os.path.isfile(file): - apply_file(file, results, fix=args.fix) - elif os.path.isdir(file): - traverse_directory(file, results, fix=args.fix, substrings_to_skip=substrings_to_skip) - else: - raise NotImplementedError - - if any(results.values()): - sys.exit(1) diff --git a/cosmos-transfer1/scripts/test_environment.py b/cosmos-transfer1/scripts/test_environment.py deleted file mode 100644 index 13217020b22f14a3eba7c503196ddfebaf394ff1..0000000000000000000000000000000000000000 --- a/cosmos-transfer1/scripts/test_environment.py +++ /dev/null @@ -1,70 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import importlib -import os -import sys - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--training", - action="store_true", - help="Whether to check training-specific dependencies", - ) - return parser.parse_args() - - -def check_packages(package_list): - global all_success - for package in package_list: - try: - _ = importlib.import_module(package) - except Exception: - print(f"\033[91m[ERROR]\033[0m Package not successfully imported: \033[93m{package}\033[0m") - all_success = False - else: - print(f"\033[92m[SUCCESS]\033[0m {package} found") - - -args = parse_args() - -if not (sys.version_info.major == 3 and sys.version_info.minor >= 10): - detected = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" - print(f"\033[91m[ERROR]\033[0m Python 3.10+ is required. You have: \033[93m{detected}\033[0m") - sys.exit(1) - -if "CONDA_PREFIX" not in os.environ: - print( - "\033[93m[WARNING]\033[0m CONDA_PREFIX is not set. When manually installed, Cosmos should run under the cosmos-transfer1 conda environment (see INSTALL.md). This warning can be ignored when running in the container." - ) - -print("Attempting to import critical packages...") - -packages = ["torch", "torchvision", "transformers", "megatron.core", "transformer_engine", "vllm", "pandas"] -packages_training = [ - "apex.multi_tensor_apply", -] -all_success = True - -check_packages(packages) -if args.training: - check_packages(packages_training) - -if all_success: - print("-----------------------------------------------------------") - print("\033[92m[SUCCESS]\033[0m Cosmos environment setup is successful!") diff --git a/cosmos-transfer1/scripts/download_checkpoints.py b/download_checkpoints.py similarity index 74% rename from cosmos-transfer1/scripts/download_checkpoints.py rename to download_checkpoints.py index 178d8e086c34ea912bab8c28f1ed4a39dc4636ee..cf9a72a3dc3cb15d312c385f1275dc47eaa65e92 100644 --- a/cosmos-transfer1/scripts/download_checkpoints.py +++ b/download_checkpoints.py @@ -1,39 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" Download checkpoints from HuggingFace Hub. - -This file downloads the checkpoints specified in the `cosmos_transfer1.checkpoints` module. - -Usage: - - CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_checkpoints.py --output_dir --hf_token - If the token is not provided, it will try to read from the `HF_TOKEN` environment variable. - -""" - -import argparse import hashlib import os import pathlib - -from huggingface_hub import login, snapshot_download +from typing import Literal # Import the checkpoint paths from cosmos_transfer1 import checkpoints from cosmos_transfer1.utils import log +from huggingface_hub import login, snapshot_download def download_checkpoint(checkpoint: str, output_dir: str) -> None: @@ -109,32 +82,30 @@ def get_md5_checksum(output_dir, model_name): return True -def main(): - parser = argparse.ArgumentParser(description="Download checkpoints from HuggingFace Hub") - parser.add_argument("--hf_token", type=str, help="HuggingFace token", default=os.environ.get("HF_TOKEN")) - parser.add_argument( - "--output_dir", type=str, help="Directory to store the downloaded checkpoints", default="./checkpoints" - ) - parser.add_argument( - "--model", type=str, help="Model type to download", default="all", choices=["all", "7b", "7b_av"] - ) - args = parser.parse_args() +def main(hf_token: str = os.environ.get("HF_TOKEN"), output_dir: str = "./checkpoints", model: Literal["all", "7b", "7b_av"] = "all"): + """ + Download checkpoints from HuggingFace Hub - if args.hf_token: - login(token=args.hf_token) + :param str hf_token: HuggingFace token + :param str output_dir: Directory to store the downloaded checkpoints + :param str model: Model type to download + """ + + if hf_token: + login(token=hf_token) checkpoint_vars = [] # Get all variables from the checkpoints module for name in dir(checkpoints): obj = getattr(checkpoints, name) if isinstance(obj, str) and "CHECKPOINT" in name and "PATH" not in name: - if args.model != "all" and name in [ + if model != "all" and name in [ "COSMOS_TRANSFER1_7B_CHECKPOINT", "COSMOS_TRANSFER1_7B_SAMPLE_AV_CHECKPOINT", ]: - if args.model == "7b" and name == "COSMOS_TRANSFER1_7B_CHECKPOINT": + if model == "7b" and name == "COSMOS_TRANSFER1_7B_CHECKPOINT": checkpoint_vars.append(obj) - elif args.model == "7b_av" and name in [ + elif model == "7b_av" and name in [ "COSMOS_TRANSFER1_7B_SAMPLE_AV_CHECKPOINT", "COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT", ]: @@ -146,8 +117,4 @@ def main(): # Download each checkpoint for checkpoint in checkpoint_vars: - download_checkpoint(checkpoint, args.output_dir) - - -if __name__ == "__main__": - main() + download_checkpoint(checkpoint, output_dir) diff --git a/helper.py b/helper.py index 6872ad58c2be43750283bbeb6559cc98f2152432..5e3d41b1be76b99728ab6bcb757eafb38f59545a 100644 --- a/helper.py +++ b/helper.py @@ -1,6 +1,6 @@ import argparse import sys -from typing import Any, Dict, Optional +from typing import Any, 
Dict, Literal, Optional sys.path.append("./cosmos-transfer1") @@ -30,8 +30,8 @@ def parse_arguments( input_video_path: str = "", num_input_frames: int = 1, sigma_max: float = 70.0, - blur_strength: str = "medium", - canny_threshold: str = "medium", + blur_strength: Literal["very_low", "low", "medium", "high", "very_high"] = "medium", + canny_threshold: Literal["very_low", "low", "medium", "high", "very_high"] = "medium", is_av_sample: bool = False, checkpoint_dir: str = "checkpoints", tokenizer_dir: str = "Cosmos-Tokenize1-CV8x8x8-720p", @@ -43,7 +43,7 @@ def parse_arguments( guidance: float = 5, fps: int = 24, seed: int = 1, - num_gpus: int = 1, + num_gpus: Literal[1] = 1, offload_diffusion_transformer: bool = False, offload_text_encoder_model: bool = False, offload_guardrail_models: bool = False, @@ -83,10 +83,6 @@ def parse_arguments( :param bool use_distilled: Use distilled ControlNet model variant """ - assert num_input_frames in [1] - assert blur_strength in ["very_low", "low", "medium", "high", "very_high"] - assert canny_threshold in ["very_low", "low", "medium", "high", "very_high"] - cmd_args = argparse.Namespace( prompt=prompt, negative_prompt=negative_prompt, diff --git a/requirements.txt b/requirements.txt index e906a039e69a1060d37355436aa9305790ad54a7..a27a7919e1c765e1046058221c5007f75dc19e13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,45 +1,10 @@ git+https://github.com/huggingface/diffusers.git +transformers accelerate sentencepiece safetensors +torchvision git+https://github.com/yiyixuxu/cosmos-guardrail.git +peft -apex==0.9.10dev -attr==0.3.2 -attrs==25.3.0 -better_profanity==0.7.0 -cloudpickle==3.1.1 -decord==0.6.0 -dill==0.4.0 -einops==0.8.1 -hydra-core==1.3.2 -imageio[pyav,ffmpeg]==2.37.0 -iopath==0.1.10 -loguru==0.7.3 -matplotlib==3.10.3 -mediapy==1.2.4 -megatron_core==0.10.0 -natsort==8.4.0 -nltk==3.9.1 -numpy==2.2.6 -nvidia_ml_py==12.570.86 -omegaconf==2.3.0 -opencv_python==4.10.0.84 -opencv_python_headless==4.11.0.86 -packaging==25.0 -pandas==2.2.3 -peft==0.14.0 -Pillow==11.2.1 -pillow -pycocotools==2.0.8 -pynvml==12.0.0 -python-magic -PyYAML==6.0.2 -retinaface_py==0.0.2 -rtmlib==0.0.13 -sam2==1.1.0 -termcolor==3.1.0 -torch==2.7.0 -torchvision==0.22.0 -tqdm==4.67.1 -transformers==4.49.0 +git+https://github.com/nvidia-cosmos/cosmos-transfer1
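
One note on the removed hdvila downloader: its `convert_time_to_seconds` helper is annotated `-> int` yet returns a float, and its separate millisecond parsing only comes out right when the timestamp has exactly three fractional digits (for "0:00:02.5" it yields 2.005 s instead of 2.5 s). A minimal corrected sketch, kept only for reference since the script is deleted in this patch:

    # Sketch: float("02.5") already carries the fractional seconds, so no
    # separate millisecond handling (or int() truncation) is needed.
    def time_to_seconds(time_str: str) -> float:
        h, m, s = map(float, time_str.split(":"))
        return h * 3600 + m * 60 + s

    assert time_to_seconds("0:00:02.5") == 2.5
    assert time_to_seconds("1:02:03") == 3723.0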
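
The checkpoint downloader now exposes a plain `main(hf_token, output_dir, model)` function instead of an argparse entry point, so it can be imported and called directly. If command-line use is still wanted, a thin shim can restore it; the sketch below assumes `download_checkpoints.py` is on the import path, and the wrapper file name `fetch_checkpoints.py` is hypothetical.

    # fetch_checkpoints.py -- hypothetical CLI shim that mirrors the removed
    # argparse options and forwards them to the refactored main().
    import argparse
    import os

    from download_checkpoints import main as download_checkpoints

    if __name__ == "__main__":
        parser = argparse.ArgumentParser(description="Download checkpoints from HuggingFace Hub")
        parser.add_argument("--hf_token", default=os.environ.get("HF_TOKEN"), help="HuggingFace token")
        parser.add_argument("--output_dir", default="./checkpoints", help="Directory to store the downloaded checkpoints")
        parser.add_argument("--model", default="all", choices=["all", "7b", "7b_av"], help="Model type to download")
        args = parser.parse_args()
        download_checkpoints(hf_token=args.hf_token, output_dir=args.output_dir, model=args.model)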
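
In `helper.py`, the runtime asserts on `blur_strength` and `canny_threshold` are replaced by `Literal` annotations, which are enforced only by static type checkers such as mypy, not when the function actually runs. If runtime validation is still desired, the accepted values can be read back from the annotation itself; the `ControlStrength` alias and `validate_choice` helper below are illustrative, not part of the diff.

    # Illustrative runtime guard; the Literal values match those now used in
    # parse_arguments for blur_strength and canny_threshold.
    from typing import Literal, cast, get_args

    ControlStrength = Literal["very_low", "low", "medium", "high", "very_high"]

    def validate_choice(value: str) -> ControlStrength:
        allowed = get_args(ControlStrength)
        if value not in allowed:
            raise ValueError(f"expected one of {allowed}, got {value!r}")
        return cast(ControlStrength, value)

    validate_choice("medium")      # passes
    # validate_choice("extreme")   # would raise ValueError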