harry900000 committed on
Commit
226c7c9
·
1 Parent(s): e22a639

add cosmos_transfer1/ into repo

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +203 -0
  2. app.py +11 -2
  3. cosmos_transfer1/auxiliary/depth_anything/inference/__init__.py +0 -0
  4. cosmos_transfer1/auxiliary/depth_anything/inference/depth_anything_pipeline.py +55 -0
  5. cosmos_transfer1/auxiliary/depth_anything/model/__init__.py +0 -0
  6. cosmos_transfer1/auxiliary/depth_anything/model/depth_anything.py +151 -0
  7. cosmos_transfer1/auxiliary/guardrail/README.md +17 -0
  8. cosmos_transfer1/auxiliary/guardrail/__init__.py +14 -0
  9. cosmos_transfer1/auxiliary/guardrail/aegis/__init__.py +14 -0
  10. cosmos_transfer1/auxiliary/guardrail/aegis/aegis.py +135 -0
  11. cosmos_transfer1/auxiliary/guardrail/aegis/categories.py +192 -0
  12. cosmos_transfer1/auxiliary/guardrail/blocklist/__init__.py +14 -0
  13. cosmos_transfer1/auxiliary/guardrail/blocklist/blocklist.py +216 -0
  14. cosmos_transfer1/auxiliary/guardrail/blocklist/utils.py +45 -0
  15. cosmos_transfer1/auxiliary/guardrail/common/__init__.py +0 -0
  16. cosmos_transfer1/auxiliary/guardrail/common/core.py +71 -0
  17. cosmos_transfer1/auxiliary/guardrail/common/io_utils.py +78 -0
  18. cosmos_transfer1/auxiliary/guardrail/common/presets.py +75 -0
  19. cosmos_transfer1/auxiliary/guardrail/face_blur_filter/__init__.py +14 -0
  20. cosmos_transfer1/auxiliary/guardrail/face_blur_filter/blur_utils.py +35 -0
  21. cosmos_transfer1/auxiliary/guardrail/face_blur_filter/face_blur_filter.py +225 -0
  22. cosmos_transfer1/auxiliary/guardrail/face_blur_filter/retinaface_utils.py +117 -0
  23. cosmos_transfer1/auxiliary/guardrail/llamaGuard3/__init__.py +14 -0
  24. cosmos_transfer1/auxiliary/guardrail/llamaGuard3/categories.py +31 -0
  25. cosmos_transfer1/auxiliary/guardrail/llamaGuard3/llamaGuard3.py +122 -0
  26. cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/__init__.py +14 -0
  27. cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/model.py +60 -0
  28. cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/video_content_safety_filter.py +185 -0
  29. cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/vision_encoder.py +46 -0
  30. cosmos_transfer1/auxiliary/human_keypoint/human_keypoint.py +155 -0
  31. cosmos_transfer1/auxiliary/robot_augmentation/README.md +112 -0
  32. cosmos_transfer1/auxiliary/robot_augmentation/spatial_temporal_weight.py +577 -0
  33. cosmos_transfer1/auxiliary/sam2/sam2_model.py +392 -0
  34. cosmos_transfer1/auxiliary/sam2/sam2_pipeline.py +126 -0
  35. cosmos_transfer1/auxiliary/sam2/sam2_utils.py +168 -0
  36. cosmos_transfer1/auxiliary/tokenizer/inference/__init__.py +14 -0
  37. cosmos_transfer1/auxiliary/tokenizer/inference/image_cli.py +188 -0
  38. cosmos_transfer1/auxiliary/tokenizer/inference/image_lib.py +124 -0
  39. cosmos_transfer1/auxiliary/tokenizer/inference/utils.py +402 -0
  40. cosmos_transfer1/auxiliary/tokenizer/inference/video_cli.py +210 -0
  41. cosmos_transfer1/auxiliary/tokenizer/inference/video_lib.py +146 -0
  42. cosmos_transfer1/auxiliary/tokenizer/modules/__init__.py +61 -0
  43. cosmos_transfer1/auxiliary/tokenizer/modules/distributions.py +42 -0
  44. cosmos_transfer1/auxiliary/tokenizer/modules/layers2d.py +329 -0
  45. cosmos_transfer1/auxiliary/tokenizer/modules/layers3d.py +969 -0
  46. cosmos_transfer1/auxiliary/tokenizer/modules/patching.py +311 -0
  47. cosmos_transfer1/auxiliary/tokenizer/modules/quantizers.py +513 -0
  48. cosmos_transfer1/auxiliary/tokenizer/modules/utils.py +116 -0
  49. cosmos_transfer1/auxiliary/tokenizer/networks/__init__.py +39 -0
  50. cosmos_transfer1/auxiliary/tokenizer/networks/configs.py +147 -0
.gitignore ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ # .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Marimo
198
+ marimo/_static/
199
+ marimo/_lsp/
200
+ __marimo__/
201
+
202
+ # Streamlit
203
+ .streamlit/secrets.toml
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
  from typing import List, Tuple
3
 
4
  import gradio as gr
@@ -33,14 +35,16 @@ download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
33
  from test_environment import main as check_environment
34
  from test_environment import setup_environment
35
 
36
- setup_environment()
37
 
38
  # setup env
39
  os.environ["CUDA_HOME"] = "/usr/local/cuda"
40
  os.environ["LD_LIBRARY_PATH"] = "$CUDA_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
41
  os.environ["PATH"] = "$CUDA_HOME/bin:/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:$PATH"
42
 
43
- check_environment()
 
 
 
44
 
45
  os.environ["TOKENIZERS_PARALLELISM"] = "false" # Workaround to suppress MP warning
46
 
@@ -279,6 +283,9 @@ def generate_video(
279
  else:
280
  actual_seed = seed
281
 
 
 
 
282
  args, control_inputs = parse_arguments(
283
  controlnet_specs_in={
284
  "hdmap": {"control_weight": 0.3, "input_control": hdmap_video_input},
@@ -294,6 +301,8 @@ def generate_video(
294
  seed=seed,
295
  )
296
  videos, prompts = inference(args, control_inputs)
 
 
297
 
298
  video = videos[0]
299
  return video, video, actual_seed
 
1
  import os
2
+ import sys
3
+ import time
4
  from typing import List, Tuple
5
 
6
  import gradio as gr
 
35
  from test_environment import main as check_environment
36
  from test_environment import setup_environment
37
 
 
38
 
39
  # setup env
40
  os.environ["CUDA_HOME"] = "/usr/local/cuda"
41
  os.environ["LD_LIBRARY_PATH"] = "$CUDA_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH"
42
  os.environ["PATH"] = "$CUDA_HOME/bin:/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:$PATH"
43
 
44
+ if not check_environment():
45
+ setup_environment()
46
+ if not check_environment():
47
+ sys.exit(1)
48
 
49
  os.environ["TOKENIZERS_PARALLELISM"] = "false" # Workaround to suppress MP warning
50
 
 
283
  else:
284
  actual_seed = seed
285
 
286
+ log.info(f"actual_seed: {actual_seed}")
287
+
288
+ start_time = time.time()
289
  args, control_inputs = parse_arguments(
290
  controlnet_specs_in={
291
  "hdmap": {"control_weight": 0.3, "input_control": hdmap_video_input},
 
301
  seed=seed,
302
  )
303
  videos, prompts = inference(args, control_inputs)
304
+ end_time = time.time()
305
+ log.info(f"Time taken: {end_time - start_time} s")
306
 
307
  video = videos[0]
308
  return video, video, actual_seed
cosmos_transfer1/auxiliary/depth_anything/inference/__init__.py ADDED
File without changes
cosmos_transfer1/auxiliary/depth_anything/inference/depth_anything_pipeline.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+
18
+ from PIL import Image
19
+
20
+ from cosmos_transfer1.auxiliary.depth_anything.model.depth_anything import DepthAnythingModel
21
+
22
+
23
def parse_args():
    """
    Build the command-line interface for the depth pipeline and parse ``sys.argv``.

    Returns:
        The parsed namespace with ``input``, ``output`` (both required) and
        ``mode`` (``"image"`` or ``"video"``, defaulting to ``"image"``).
    """
    cli = argparse.ArgumentParser(description="Depth Estimation using Depth Anything V2")

    # The two path options share every setting except their flag and help text.
    required_paths = (
        ("--input", "Path to input image or video file"),
        ("--output", "Path to save the output image or video"),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    mode_options = {
        "type": str,
        "choices": ["image", "video"],
        "default": "image",
        "help": "Processing mode: 'image' for a single image, 'video' for a video file",
    }
    cli.add_argument("--mode", **mode_options)

    return cli.parse_args()
35
+
36
+
37
def main():
    """
    Run depth estimation on the image or video selected on the command line.

    In ``image`` mode the input is loaded as RGB, passed to
    ``DepthAnythingModel.predict_depth`` and the resulting depth image is saved
    to ``--output``. In ``video`` mode the model instance is called with the
    input and output paths; a confirmation is printed only when it returns a
    truthy output path.
    """
    args = parse_args()
    model = DepthAnythingModel()

    if args.mode == "image":
        # Load the input image and predict its depth; the context manager closes
        # the underlying file once the RGB copy has been made.
        with Image.open(args.input) as source_image:
            image = source_image.convert("RGB")
        depth_image = model.predict_depth(image)
        depth_image.save(args.output)
        print(f"Depth image saved to {args.output}")
    elif args.mode == "video":
        # Process the video and save the output. DepthAnythingModel exposes video
        # processing through __call__; it defines no `predict_depth_video` method.
        out_path = model(args.input, args.output)
        if out_path:
            print(f"Depth video saved to {out_path}")


if __name__ == "__main__":
    main()
cosmos_transfer1/auxiliary/depth_anything/model/__init__.py ADDED
File without changes
cosmos_transfer1/auxiliary/depth_anything/model/depth_anything.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ import cv2
19
+ import imageio
20
+ import numpy as np
21
+ import torch
22
+ from PIL import Image
23
+ from transformers import AutoImageProcessor, AutoModelForDepthEstimation
24
+
25
+ from cosmos_transfer1.checkpoints import DEPTH_ANYTHING_MODEL_CHECKPOINT
26
+ from cosmos_transfer1.utils import log
27
+
28
+
29
class DepthAnythingModel:
    """
    Depth estimator backed by the checkpoint named by ``DEPTH_ANYTHING_MODEL_CHECKPOINT``.

    ``predict_depth`` handles a single PIL image; calling the instance processes
    a video file frame by frame and writes a depth video.
    """

    def __init__(self):
        """
        Initialize the Depth Anything model and its image processor.

        Both are loaded in half precision; the model is moved to CUDA when it is
        available and to the CPU otherwise.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load image processor and model with half precision
        print(f"Loading Depth Anything model - {DEPTH_ANYTHING_MODEL_CHECKPOINT}...")
        self.image_processor = AutoImageProcessor.from_pretrained(
            DEPTH_ANYTHING_MODEL_CHECKPOINT,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        self.model = AutoModelForDepthEstimation.from_pretrained(
            DEPTH_ANYTHING_MODEL_CHECKPOINT,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        ).to(self.device)

    def _estimate_depth(self, image: Image.Image, size) -> np.ndarray:
        """
        Run the model on one RGB image and return its depth map resized to ``size``.

        Args:
            image: Input image.
            size: Target ``(height, width)``, the order ``interpolate`` expects.

        Returns:
            The resized depth prediction as a float32 array.
        """
        # Prepare inputs for the model
        inputs = self.image_processor(images=image, return_tensors="pt")
        # Move all tensors to the proper device with half precision
        inputs = {k: v.to(self.device, dtype=torch.float16) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            predicted_depth = outputs.predicted_depth

        # A single image is submitted, so take the first prediction and add the
        # batch and channel dimensions before resizing.
        # assumes predicted_depth is (batch, height, width) -- TODO confirm
        prediction = torch.nn.functional.interpolate(
            predicted_depth[0].unsqueeze(0).unsqueeze(0),
            size=size,
            mode="bicubic",
            align_corners=False,
        )
        # Leave half precision before handing the data to NumPy: float16 cannot
        # represent the 1e-8 epsilon or the 65535 scale used during normalization.
        return prediction.squeeze().float().cpu().numpy()

    def predict_depth(self, image: Image.Image) -> Image.Image:
        """
        Process a single PIL image and return a depth map as a uint16 PIL Image.
        """
        # PIL image size is (width, height), interpolate expects (height, width)
        output = self._estimate_depth(image, image.size[::-1])
        depth_image = DepthAnythingModel.save_depth(output)
        return depth_image

    def __call__(self, input_video: str, output_video: str = "depth.mp4") -> str:
        """
        Process a video file frame-by-frame to produce a depth-estimated video.
        The output video is saved as an MP4 file.

        Depth values are normalized to 0-255 over the whole clip (global min/max),
        not per frame.

        Args:
            input_video: Path of the video to read; it must exist.
            output_video: Path of the video to write. Its parent directory is
                created when the path has one.

        Returns:
            ``output_video`` on success, or ``None`` when the input cannot be
            opened or no frame can be read from it.
        """

        log.info(f"Processing video: {input_video} to generate depth video: {output_video}")
        assert os.path.exists(input_video)

        cap = cv2.VideoCapture(input_video)
        depths = []
        try:
            if not cap.isOpened():
                print("Error: Cannot open video file.")
                return None

            # Retrieve video properties
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Convert frame from BGR to RGB and then to PIL Image
                image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                depths.append(self._estimate_depth(image, (frame_height, frame_width)))
        finally:
            # Release the capture even when inference raises part-way through.
            cap.release()

        if not depths:
            # np.stack rejects an empty list, so report the problem instead.
            print("Error: No frames could be read from video file.")
            return None

        depths = np.stack(depths)
        depths_normed = (depths - depths.min()) / (depths.max() - depths.min() + 1e-8) * 255.0
        depths_normed = depths_normed.astype(np.uint8)

        # A bare file name such as the default "depth.mp4" has an empty dirname,
        # and os.makedirs("") raises FileNotFoundError.
        output_dir = os.path.dirname(output_video)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        self.write_video(depths_normed, output_video, fps=fps)
        return output_video

    @staticmethod
    def save_depth(output: np.ndarray) -> Image.Image:
        """
        Convert the raw depth output (float values) into a uint16 PIL Image.

        The values are rescaled linearly so that the minimum maps to 0 and the
        maximum to 65535; a constant input yields an all-zero image.
        """
        # Rescale in float64 regardless of the caller's dtype: half-precision
        # input would overflow when multiplied by 65535 (float16 max is 65504).
        output = np.asarray(output, dtype=np.float64)
        depth_min = output.min()
        depth_max = output.max()
        max_val = (2**16) - 1  # Maximum value for uint16

        if depth_max - depth_min > np.finfo("float").eps:
            out_array = max_val * (output - depth_min) / (depth_max - depth_min)
        else:
            out_array = np.zeros_like(output)

        formatted = out_array.astype("uint16")
        depth_image = Image.fromarray(formatted, mode="I;16")
        return depth_image

    @staticmethod
    def write_video(frames, output_path, fps=30):
        """
        Write ``frames`` to ``output_path`` with imageio at ``fps`` frames per second.

        Two-dimensional (single channel) frames are repeated across three
        channels before being appended.
        """
        with imageio.get_writer(output_path, fps=fps, macro_block_size=8) as writer:
            for frame in frames:
                if len(frame.shape) == 2:  # single channel
                    frame = frame[:, :, None].repeat(3, axis=2)
                writer.append_data(frame)
cosmos_transfer1/auxiliary/guardrail/README.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Cosmos Guardrail
2
+
3
+ This page outlines a set of tools to ensure content safety in Cosmos. For implementation details, please consult the [Cosmos paper](https://research.nvidia.com/publication/2025-01_cosmos-world-foundation-model-platform-physical-ai).
4
+
5
+ ## Overview
6
+
7
+ Our guardrail system consists of two stages: pre-Guard and post-Guard.
8
+
9
+ Cosmos pre-Guard models are applied to text input, including input prompts and upsampled prompts.
10
+
11
+ * Blocklist: a keyword list checker for detecting harmful keywords
12
+ * Llama Guard 3: an LLM-based approach for blocking harmful prompts
13
+
14
+ Cosmos post-Guard models are applied to video frames generated by Cosmos models.
15
+
16
+ * Video Content Safety Filter: a classifier trained to distinguish between safe and unsafe video frames
17
+ * Face Blur Filter: a face detection and blurring module
cosmos_transfer1/auxiliary/guardrail/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
cosmos_transfer1/auxiliary/guardrail/aegis/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
cosmos_transfer1/auxiliary/guardrail/aegis/aegis.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+
19
+ import torch
20
+ from peft import PeftModel
21
+ from transformers import AutoModelForCausalLM, AutoTokenizer
22
+
23
+ from cosmos_transfer1.auxiliary.guardrail.aegis.categories import UNSAFE_CATEGORIES
24
+ from cosmos_transfer1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner
25
+ from cosmos_transfer1.checkpoints import AEGIS_MODEL_CHECKPOINT, LLAMA_GUARD_MODEL_CHECKPOINT
26
+ from cosmos_transfer1.utils import log, misc
27
+
28
+ SAFE = misc.Color.green("SAFE")
29
+ UNSAFE = misc.Color.red("UNSAFE")
30
+
31
+
32
class Aegis(ContentSafetyGuardrail):
    """Prompt guardrail that runs the LlamaGuard-7b base model with the Aegis adapter."""

    def __init__(
        self,
        checkpoint_dir: str,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
        """
        Load the base model, its tokenizer and the Aegis adapter.

        Args:
            checkpoint_dir: Root directory; each model uses a cache directory
                below it named after its hub id.
            device: Device the model is moved to. The default is evaluated once,
                when the class is defined: CUDA if available, otherwise CPU.
        """
        self.checkpoint_dir = checkpoint_dir
        self.device = device
        self.dtype = torch.bfloat16
        base_model_id = "meta-llama/LlamaGuard-7b"
        aegis_adapter = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0"

        base_model_dir = os.path.join(self.checkpoint_dir, base_model_id)
        aegis_adapter_dir = os.path.join(self.checkpoint_dir, aegis_adapter)

        base_model = AutoModelForCausalLM.from_pretrained(base_model_id, cache_dir=base_model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id, cache_dir=base_model_dir)
        self.model = PeftModel.from_pretrained(base_model, aegis_adapter, cache_dir=aegis_adapter_dir)

        self.model.to(self.device, dtype=self.dtype).eval()

    def get_moderation_prompt(self, user_prompt: str) -> str:
        """Create the moderation prompt for the Aegis model.

        The prompt lists the ``instruction`` text of every entry in
        ``UNSAFE_CATEGORIES``, followed by ``user_prompt`` as the conversation
        and the expected answer format.
        """
        full_prompt = f"""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
{"".join([f"{category['instruction']}" for category in UNSAFE_CATEGORIES])}
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

"""
        full_prompt += f"User: {user_prompt}\n\n"
        full_prompt += """<END CONVERSATION>

Provide your safety assessment for User in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""

        return full_prompt

    def get_aegis_block_message(self, moderation_output: str) -> str:
        """Extract the blocked category and reason from the Aegis model output.

        Args:
            moderation_output: Decoded model answer; its second line is expected
                to hold category codes such as ``O1`` or ``O1,O3``.

        Returns:
            ``"<category>: <reason>"`` for the first listed category when it maps
            to an entry of ``UNSAFE_CATEGORIES``, otherwise a generic message.
        """
        block_msg = "Prompt blocked by Aegis."
        try:
            # Extract blocked category and reason
            output_lines = moderation_output.split("\n")
            if len(output_lines) >= 2:
                # The prompt asks for a comma-separated list, so the line may name
                # several categories; report the first one instead of failing to
                # parse the whole line as a single number.
                first_code = output_lines[1].split(",")[0].strip()
                blocked_category = int(first_code.lstrip("O")) - 1
                if 0 <= blocked_category < len(UNSAFE_CATEGORIES):
                    blocked_reason = UNSAFE_CATEGORIES[blocked_category]["blocked_reason"]
                    blocked_category_name = UNSAFE_CATEGORIES[blocked_category]["category"]
                    block_msg = f"{blocked_category_name}: {blocked_reason}"
        except Exception as e:
            log.warning(f"Unable to extract blocked category and reason from Aegis output: {e}")
        return block_msg

    def filter_aegis_output(self, prompt: str) -> tuple[bool, str]:
        """Filter the Aegis model output and return the safety status and message.

        Returns:
            ``(True, "")`` when the generated answer does not contain "unsafe"
            (case-insensitive), otherwise ``(False, block_message)``.
        """
        full_prompt = self.get_moderation_prompt(prompt)
        inputs = self.tokenizer([full_prompt], add_special_tokens=False, return_tensors="pt").to(self.device)
        output = self.model.generate(**inputs, max_new_tokens=100, pad_token_id=self.tokenizer.eos_token_id)
        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[-1]
        moderation_output = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

        if "unsafe" in moderation_output.lower():
            block_msg = self.get_aegis_block_message(moderation_output)
            return False, block_msg
        else:
            return True, ""

    def is_safe(self, prompt: str) -> tuple[bool, str]:
        """Check if the input prompt is safe according to the Aegis model.

        Returns:
            The result of ``filter_aegis_output``. If that call raises, the error
            is logged and the prompt is reported as safe.
        """
        try:
            return self.filter_aegis_output(prompt)
        except Exception as e:
            # NOTE(review): this fails open -- any runtime error lets the prompt
            # through as safe. Kept as-is; confirm this is the intended policy.
            log.error(f"Unexpected error occurred when running Aegis guardrail: {e}")
            return True, "Unexpected error occurred when running Aegis guardrail."
110
+
111
+
112
# Default root for --checkpoint_dir. The original code referenced this name
# without defining or importing it, so parse_args() raised NameError.
# NOTE(review): "checkpoints" is an assumed location -- confirm against the
# project's checkpoint layout.
DEFAULT_CHECKPOINT_DIR = "checkpoints"


def parse_args():
    """
    Parse the command-line arguments of the Aegis guardrail script.

    Returns:
        The parsed namespace with ``prompt`` (required) and ``checkpoint_dir``
        (defaults to ``DEFAULT_CHECKPOINT_DIR``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="Path to the Aegis checkpoint folder",
        default=DEFAULT_CHECKPOINT_DIR,
    )
    return parser.parse_args()
122
+
123
+
124
def main(args):
    """
    Run the Aegis guardrail on ``args.prompt`` and log the verdict.

    The guardrail is built from ``args.checkpoint_dir`` and executed through a
    ``GuardrailRunner`` inside a timer. The block message is logged only when
    the prompt is not safe.
    """
    guardrail = Aegis(checkpoint_dir=args.checkpoint_dir)
    runner = GuardrailRunner(safety_models=[guardrail])

    with misc.timer("aegis safety check"):
        safety, message = runner.run_safety_check(args.prompt)

    if safety:
        verdict = "SAFE"
    else:
        verdict = "UNSAFE"
    log.info(f"Input is: {verdict}")

    if not safety:
        log.info(f"Message: {message}")


if __name__ == "__main__":
    args = parse_args()
    main(args)
cosmos_transfer1/auxiliary/guardrail/aegis/categories.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ VIOLENCE = {
17
+ "category": "violence",
18
+ "blocked_reason": """This prompt is blocked due to the presence of violent content.""",
19
+ "instruction": """O1: Violence.
20
+ Should not
21
+ - Help people plan or engage in violence.
22
+ - Encourage people to engage in violence.
23
+ - Contain unpleasant vivid violent imagery or violent description that arouses discomfort in the reader.
24
+ Can
25
+ - Provide information on violence.
26
+ - Discuss the topics of violence, and can discuss historical events involving violence.
27
+ """,
28
+ }
29
+
30
+ SEXUAL = {
31
+ "category": "sexual",
32
+ "blocked_reason": """This prompt is blocked due to the presence of sexual content.""",
33
+ "instruction": """O2: Sexual.
34
+ Should not
35
+ - Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
36
+ - Engage or encourage sexual activity that is considered illegal.
37
+ Can
38
+ - Discuss topics related to sexuality and sexual education.
39
+ - Provide general information regarding sexual activities.
40
+ - Provide guidance on topics related to sexual health.
41
+ """,
42
+ }
43
+
44
+ CRIMINAL_PLANNING = {
45
+ "category": "criminal_planning",
46
+ "blocked_reason": """This prompt is blocked due to the presence of criminal planning/confessions content.""",
47
+ "instruction": """O3: Criminal Planning/Confessions.
48
+ Should not
49
+ - Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
50
+ - Encourage people to engage in criminal activities or condone criminal behavior.
51
+ - Have a confession of crime that was committed or will be committed in future.
52
+ Can
53
+ - Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
54
+ - Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.
55
+ """,
56
+ }
57
+
58
+ WEAPONS = {
59
+ "category": "weapons",
60
+ "blocked_reason": """This prompt is blocked due to the presence of weapons content.""",
61
+ "instruction": """O4: Guns and Illegal Weapons.
62
+ Should not
63
+ - Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
64
+ - Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
65
+ - Assist or encourage people to execute or own other illegal weapons for example biological agents, or chemical weapons, including instructions on how to create or use them.
66
+ Can
67
+ - Discuss firearms and the arguments for and against firearm ownership.
68
+ """,
69
+ }
70
+
71
+ SUBSTANCE_ABUSE = {
72
+ "category": "substance_abuse",
73
+ "blocked_reason": """This prompt is blocked due to the presence of substance abuse content.""",
74
+ "instruction": """O5: Controlled/Regulated Substances.
75
+ Should not
76
+ - Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
77
+ - Assist or encourage people to create such substances when it is illegal to do so.
78
+ Can
79
+ - Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
80
+ - Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.
81
+ """,
82
+ }
83
+
84
+ SUICIDE = {
85
+ "category": "suicide",
86
+ "blocked_reason": "This prompt is blocked due to the presence of suicide and self harm content.",
87
+ "instruction": """O6: Suicide and Self Harm.
88
+ Should not
89
+ - Encourage people to harm themselves.
90
+ - Romanticize or condone self-harm or suicide.
91
+ - Provide information regarding the methods of suicide or self-harm.
92
+ - Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
93
+ - Contain confessions or intent of self harm
94
+ Can
95
+ - Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
96
+ Should
97
+ - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.
98
+ """,
99
+ }
100
+
101
+ CSAM = {
102
+ "category": "csam",
103
+ "blocked_reason": "This prompt contains material related to Child Sexual Abuse",
104
+ "instruction": """O7: Sexual (minor).
105
+ Should not
106
+ - Assist, engage or encourage intention to engage in sexual activity with an underage/minor.
107
+ - Discuss or plan sexual acts, sharing of media of sexual nature, personal information or intention to meet in person for sexual activity with a minor.
108
+ - Should not include discussion where private information is shared with underage minors for purposes of acting on them and taking actions that are sexual in nature including meeting or sharing media.
109
+ - Should not discuss child abuse or engage, encourage, assist in sexual grooming/trafficking of children.
110
+ Can
111
+ - Provide information about sex education for use in schools.
112
+ """,
113
+ }
114
+
115
+ HATRED = {
116
+ "category": "hatred",
117
+ "blocked_reason": "This prompt is blocked due to the presence of hate speech content.",
118
+ "instruction": """O8: Hate/Identity Hate.
119
+ Should not
120
+ - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their
121
+ race, color, ethnicity, religion, national origin, sexual orientation, gender, gender identity, or disability or caste.
122
+ - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
123
+ - Should not discuss prejudice towards a group classified as protected according to the law of Germany, UK, USA or any other country.
124
+ - Should not contain slurs.
125
+ - Should not contain text that is distasteful, offensive, or hateful that causes others to feel grief, anger, or fear.
126
+ Can
127
+ - Provide information on discrimination.
128
+ - Discuss the topics of discrimination, and can discuss historical events involving discrimination.
129
+ - Discuss neutral statements referring to group identity and contextualized hate speech (that talks about hate but does not incite hate).
130
+ """,
131
+ }
132
+
133
+ HARASSMENT = {
134
+ "category": "harassment",
135
+ "blocked_reason": "This prompt is blocked due to the presence of harassment content.",
136
+ "instruction": """O10: Harassment.
137
+ Should not
138
+ - Have content that may be used to torment or annoy individuals in real life, or make harassment more likely to occur.
139
+ - Engage in harassment, misconduct and unlawful discrimination
140
+ - Single out an individual or marginalization and/or retaliation based on protected characteristics like Race, Color, Gender, Sex,
141
+ Sexual orientation, Gender identity, and gender expression, National origin, Ethnicity, Disability (including being regarded as disabled),
142
+ Religion, Age (40+), Pregnancy (including pregnancy, childbirth or related medical conditions), Genetic information, Military or veteran status, Citizenship status, Political activity, or affiliation,
143
+ - Taking or requesting statutorily protected leave Body characteristics.
144
+ - Promise to give a benefit, or a threat to retaliate or take an adverse action based on a response to harassment and misconduct.
145
+ - Bullying.
146
+ - Being derogatory and humiliating toward an individual but not necessarily protected characteristics under law.
147
+ - Making rude or insulting comments, demeaning, and objectifying terms toward an individual.
148
+ Can
149
+ - Discuss negative impacts of harassment and bullying based on protected characteristics like race, color, gender, sex, nationality, within a contextualized setting.
150
+ - Provide information on protected characteristics or refers, cites material that describes what protected characteristics are.
151
+ - Provide information or refer or cite material that informs what harassment is.
152
+ """,
153
+ }
154
+
155
+ THREAT = {
156
+ "category": "threat",
157
+ "blocked_reason": "This prompt is blocked due to the presence of threat content.",
158
+ "instruction": """O11: Threat.
159
+ Should not
160
+ - Discuss a communication of intent to inflict harm or loss on another person.
161
+ - Should not contain discussion of intention of intimidation which is a tactic used between conflicting parties to make the other timid or psychologically insecure for coercion or control.
162
+ - Should not contain intention of knowingly putting another person in fear of bodily injury.
163
+ - Should not assist or plan injury, physical or mental damage, act or instance of injury, or material or any kind of loss to a person.
164
+ - Should not contain threats that are forbidden by law as intent to obtain a monetary advantage or to compel a person to act against their will.
165
+ Can
166
+ - Provide information about threat is and the meaning of threat.
167
+ """,
168
+ }
169
+
170
+ PROFANITY = {
171
+ "category": "profanity",
172
+ "blocked_reason": "This prompt is blocked due to the presence of profanity content.",
173
+ "instruction": """O12: Profanity.
174
+ Should not
175
+ - Contain swear words, curse words, or other obscene or profane language
176
+ - This includes offensive words used without any intention to act on them or without any explicit targeted recipient
177
+ """,
178
+ }
179
+
180
# Every category dict defined in this module, collected in definition order.
# Each entry carries the keys "category", "blocked_reason" and "instruction".
UNSAFE_CATEGORIES = [
    VIOLENCE,
    SEXUAL,
    CRIMINAL_PLANNING,
    WEAPONS,
    SUBSTANCE_ABUSE,
    SUICIDE,
    CSAM,
    HATRED,
    HARASSMENT,
    THREAT,
    PROFANITY,
]
cosmos_transfer1/auxiliary/guardrail/blocklist/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
cosmos_transfer1/auxiliary/guardrail/blocklist/blocklist.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+ import re
19
+ import string
20
+ from difflib import SequenceMatcher
21
+
22
+ import nltk
23
+ from better_profanity import profanity
24
+
25
+ from cosmos_transfer1.auxiliary.guardrail.blocklist.utils import read_keyword_list_from_dir, to_ascii
26
+ from cosmos_transfer1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner
27
+ from cosmos_transfer1.utils import log, misc
28
+
29
# Censor marker handed to better-profanity as `censor_char`; its presence in the
# censored text is what Blocklist.censor_prompt uses to decide a prompt is blocked.
CENSOR = misc.Color.red("*")
30
+
31
+
32
class Blocklist(ContentSafetyGuardrail):
    """Text guardrail that blocks prompts containing blocklisted words or phrases.

    ``is_safe`` applies three checks in order: better-profanity censoring of
    the raw prompt, the same censoring of the lemmatized prompt, and a
    whole-word / fuzzy comparison against the exact-match keyword list.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        guardrail_partial_match_min_chars: int = 6,
        guardrail_partial_match_letter_count: float = 0.4,
    ) -> None:
        """Load the keyword lists and configure better-profanity.

        Args:
            checkpoint_dir: Root checkpoint folder; the blocklist data is read
                from ``nvidia/Cosmos-Guardrail1/blocklist`` beneath it.
            guardrail_partial_match_min_chars: Minimum keyword length for the
                fuzzy (partial) match to be attempted.
            guardrail_partial_match_letter_count: Allowed character difference
                used to derive the fuzzy-match similarity threshold.
        """
        self.checkpoint_dir = os.path.join(checkpoint_dir, "nvidia/Cosmos-Guardrail1/blocklist")
        nltk.data.path.append(os.path.join(self.checkpoint_dir, "nltk_data"))
        self.lemmatizer = nltk.WordNetLemmatizer()
        self.profanity = profanity
        self.guardrail_partial_match_min_chars = guardrail_partial_match_min_chars
        self.guardrail_partial_match_letter_count = guardrail_partial_match_letter_count

        # Load blocklist and whitelist keywords
        self.blocklist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "custom"))
        self.whitelist_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "whitelist"))
        self.exact_match_words = read_keyword_list_from_dir(os.path.join(self.checkpoint_dir, "exact_match"))

        self.profanity.load_censor_words(custom_words=self.blocklist_words, whitelist_words=self.whitelist_words)
        log.debug(f"Loaded {len(self.blocklist_words)} words/phrases from blocklist")
        log.debug(f"Whitelisted {len(self.whitelist_words)} words/phrases from whitelist")
        log.debug(f"Loaded {len(self.exact_match_words)} exact match words/phrases from blocklist")

    def uncensor_whitelist(self, input_prompt: str, censored_prompt: str) -> str:
        """Explicitly uncensor words that are in the whitelist.

        Args:
            input_prompt: The prompt before censoring.
            censored_prompt: The same prompt after censoring.

        Returns:
            The censored prompt with whitelisted tokens restored. If censoring
            changed the number of whitespace-separated tokens, the censored
            prompt is returned unchanged.
        """
        input_words = input_prompt.split()
        censored_words = censored_prompt.split()
        if len(input_words) != len(censored_words):
            # Positions no longer line up, so restoring by index could raise
            # IndexError or overwrite the wrong token. Keep the censored text.
            return censored_prompt

        whitelist_words = set(self.whitelist_words)
        for i, token in enumerate(input_words):
            if token.strip(string.punctuation).lower() in whitelist_words:
                censored_words[i] = token
        return " ".join(censored_words)

    def censor_prompt(self, input_prompt: str) -> tuple[bool, str]:
        """Censor the prompt using the blocklist with better-profanity fuzzy matching.

        Args:
            input_prompt: input prompt to censor

        Returns:
            bool: True if the prompt is blocked, False otherwise
            str: A message indicating why the prompt was blocked
        """
        censored_prompt = self.profanity.censor(input_prompt, censor_char=CENSOR)
        # Uncensor whitelisted words that were censored from blocklist fuzzy matching
        censored_prompt = self.uncensor_whitelist(input_prompt, censored_prompt)
        if CENSOR in censored_prompt:
            return True, f"Prompt blocked by censorship: Censored Prompt: {censored_prompt}"
        return False, ""

    @staticmethod
    def check_partial_match(
        normalized_prompt: str, normalized_word: str, guardrail_partial_match_letter_count: float
    ) -> tuple[bool, str]:
        """Fuzzy-match a keyword against every same-length word window of the prompt.

        A window matches when its ``SequenceMatcher`` ratio against the keyword
        is at least ``(len(word) - letter_count) / len(word)``.

        Args:
            normalized_prompt: a string with many words
            normalized_word: a string with one or multiple words
            guardrail_partial_match_letter_count: allowed difference in
                characters (float to allow partial characters)

        Returns:
            bool: True if a match is found, False otherwise
            str: A message indicating why the prompt was blocked
        """
        if not normalized_word:
            # An empty keyword cannot match anything and would divide by zero below.
            return False, ""

        prompt_words = normalized_prompt.split()
        word_length = len(normalized_word.split())
        min_similarity_ratio = (len(normalized_word) - float(guardrail_partial_match_letter_count)) / float(
            len(normalized_word)
        )

        for i in range(len(prompt_words) - word_length + 1):
            # Extract a substring from the prompt with the same number of words as the normalized_word
            substring = " ".join(prompt_words[i : i + word_length])
            similarity_ratio = SequenceMatcher(None, substring, normalized_word).ratio()
            if similarity_ratio >= min_similarity_ratio:
                return (
                    True,
                    f"Prompt blocked by partial match blocklist: Prompt: {normalized_prompt}, Partial Match Word: {normalized_word}",
                )

        return False, ""

    @staticmethod
    def check_against_whole_word_blocklist(
        prompt: str,
        blocklist: list[str],
        guardrail_partial_match_min_chars: int = 6,
        guardrail_partial_match_letter_count: float = 0.4,
    ) -> tuple[bool, str]:
        """
        Check if the prompt contains any whole words from the blocklist.
        The match is case insensitive and robust to multiple spaces between words.

        Args:
            prompt: input prompt to check
            blocklist: list of words to check against
            guardrail_partial_match_min_chars: minimum number of characters in a word to check for partial match
            guardrail_partial_match_letter_count: maximum allowed difference in characters for partial match

        Returns:
            bool: True if a match is found, False otherwise
            str: A message indicating why the prompt was blocked
        """
        # Normalize spaces and convert to lowercase
        normalized_prompt = re.sub(r"\s+", " ", prompt).strip().lower()

        for word in blocklist:
            # Normalize spaces and convert to lowercase for each blocklist word
            normalized_word = re.sub(r"\s+", " ", word).strip().lower()
            if not normalized_word:
                # A blank entry would build the pattern r"\b\b", which matches
                # at any word boundary and would block every prompt.
                continue

            # Use word boundaries to ensure whole word match
            if re.search(r"\b" + re.escape(normalized_word) + r"\b", normalized_prompt):
                return True, f"Prompt blocked by exact match blocklist: Prompt: {prompt}, Exact Match Word: {word}"

            # Check for partial match if the word is long enough
            if len(normalized_word) >= guardrail_partial_match_min_chars:
                match, message = Blocklist.check_partial_match(
                    normalized_prompt, normalized_word, guardrail_partial_match_letter_count
                )
                if match:
                    return True, message

        return False, ""

    def is_safe(self, input_prompt: str = "") -> tuple[bool, str]:
        """Check if the input prompt is safe using the blocklist.

        Returns:
            ``(True, "Input is safe")`` when every check passes; otherwise
            ``(False, reason)``. An empty prompt is reported as unsafe.
        """
        # Check if the input is empty
        if not input_prompt:
            return False, "Input is empty"
        input_prompt = to_ascii(input_prompt)

        # Check full sentence for censored words
        censored, message = self.censor_prompt(input_prompt)
        if censored:
            return False, message

        # Check lemmatized words for censored words
        tokens = nltk.word_tokenize(input_prompt)
        lemmas = [self.lemmatizer.lemmatize(token) for token in tokens]
        lemmatized_prompt = " ".join(lemmas)
        censored, message = self.censor_prompt(lemmatized_prompt)
        if censored:
            return False, message

        # Check for exact match blocklist words
        censored, message = self.check_against_whole_word_blocklist(
            input_prompt,
            self.exact_match_words,
            self.guardrail_partial_match_min_chars,
            self.guardrail_partial_match_letter_count,
        )
        if censored:
            return False, message

        # If all these checks pass, the input is safe
        return True, "Input is safe"
192
+
193
+
194
def parse_args():
    """Parse the command-line options of the standalone blocklist check.

    Both options are required: ``--checkpoint_dir`` has no default, and a
    missing value would otherwise reach ``os.path.join`` as ``None``.

    Returns:
        The argparse namespace with ``prompt`` and ``checkpoint_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        required=True,
        help="Path to the Blocklist checkpoint folder",
    )
    return parser.parse_args()
203
+
204
+
205
def main(args):
    """Run the blocklist guardrail on ``args.prompt`` and log the verdict.

    Args:
        args: Parsed CLI namespace providing ``prompt`` and ``checkpoint_dir``.
    """
    blocklist = Blocklist(checkpoint_dir=args.checkpoint_dir)
    runner = GuardrailRunner(safety_models=[blocklist])
    with misc.timer("blocklist safety check"):
        safety, message = runner.run_safety_check(args.prompt)
    log.info(f"Input is: {'SAFE' if safety else 'UNSAFE'}")
    # The reason is only logged for an unsafe verdict.
    if not safety:
        log.info(f"Message: {message}")
212
+
213
+
214
# Script entry point: parse the CLI options and run the blocklist check on them.
if __name__ == "__main__":
    args = parse_args()
    main(args)
cosmos_transfer1/auxiliary/guardrail/blocklist/utils.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import re
18
+
19
+ from cosmos_transfer1.utils import log
20
+
21
+
22
def read_keyword_list_from_dir(folder_path: str) -> list[str]:
    """Read keywords, one per line, from every regular file directly inside a folder.

    Lines are stripped of surrounding whitespace and blank lines are dropped,
    so the result never contains an empty keyword. Sub-directories are
    ignored. A file that cannot be read is logged and skipped.

    Args:
        folder_path: Directory holding the keyword files.

    Returns:
        All non-empty keywords, in the order the files are listed and read.
    """
    output_list: list[str] = []
    for file in os.listdir(folder_path):
        file_path = os.path.join(folder_path, file)
        if not os.path.isfile(file_path):
            continue
        try:
            with open(file_path, "r") as f:
                # Build the full list first so a failure mid-file contributes nothing.
                keywords = [stripped for line in f if (stripped := line.strip())]
        except (OSError, UnicodeError) as e:
            log.error(f"Error reading file {file}: {str(e)}")
        else:
            output_list.extend(keywords)

    return output_list
41
+
42
+
43
def to_ascii(prompt: str) -> str:
    """Replace each run of non-ASCII characters in the prompt with a single space."""
    non_ascii_run = re.compile(r"[^\x00-\x7F]+")
    return non_ascii_run.sub(" ", prompt)
cosmos_transfer1/auxiliary/guardrail/common/__init__.py ADDED
File without changes
cosmos_transfer1/auxiliary/guardrail/common/core.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Tuple
17
+
18
+ import numpy as np
19
+
20
+ from cosmos_transfer1.utils import log
21
+
22
+
23
class ContentSafetyGuardrail:
    """Base class for guardrails that decide whether an input is safe."""

    def is_safe(self, *args, **kwargs) -> Tuple[bool, str]:
        """Return ``(safe, message)`` for the given input; subclasses must override.

        Positional arguments are accepted because ``GuardrailRunner`` calls
        ``is_safe(input)`` positionally; without them the base class would
        raise ``TypeError`` instead of the intended ``NotImplementedError``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Child classes must implement the is_safe method")
26
+
27
+
28
class PostprocessingGuardrail:
    """Base class for guardrails that transform video frames."""

    def postprocess(self, frames: np.ndarray) -> np.ndarray:
        """Return the processed frames; subclasses must override this method."""
        reason = "Child classes must implement the postprocess method"
        raise NotImplementedError(reason)
31
+
32
+
33
class GuardrailRunner:
    """Runs a list of safety guardrails and a list of frame postprocessors."""

    def __init__(
        self,
        safety_models: list[ContentSafetyGuardrail] | None = None,
        generic_block_msg: str = "",
        generic_safe_msg: str = "",
        postprocessors: list[PostprocessingGuardrail] | None = None,
    ):
        """Store the guardrails and the messages reported to callers.

        Args:
            safety_models: Guardrails consulted, in order, by ``run_safety_check``.
            generic_block_msg: If non-empty, replaces the per-guardrail reason
                when an input is blocked.
            generic_safe_msg: Message returned for safe inputs; falls back to
                "Prompt is safe" when empty.
            postprocessors: Guardrails applied, in order, by ``postprocess``.
        """
        self.safety_models = safety_models
        self.postprocessors = postprocessors
        self.generic_block_msg = generic_block_msg
        self.generic_safe_msg = generic_safe_msg or "Prompt is safe"

    def run_safety_check(self, input: Any) -> Tuple[bool, str]:
        """Run the input through each safety model, stopping at the first block.

        Returns:
            ``(True, generic_safe_msg)`` when no model is configured or every
            model accepts the input; otherwise ``(False, reason)``.
        """
        if not self.safety_models:
            log.warning("No safety models found, returning safe")
            return True, self.generic_safe_msg

        for model in self.safety_models:
            model_name = str(model.__class__.__name__).upper()
            log.debug(f"Running guardrail: {model_name}")
            safe, message = model.is_safe(input)
            if safe:
                continue
            # Prefer the configured generic message over the guardrail's own reason.
            return False, self.generic_block_msg or f"{model_name}: {message}"
        return True, self.generic_safe_msg

    def postprocess(self, frames: np.ndarray) -> np.ndarray:
        """Feed the frames through each postprocessor in turn and return the result."""
        if not self.postprocessors:
            log.warning("No postprocessors found, returning original frames")
            return frames

        processed = frames
        for processor in self.postprocessors:
            processor_name = str(processor.__class__.__name__).upper()
            log.debug(f"Running guardrail: {processor_name}")
            processed = processor.postprocess(processed)
        return processed
cosmos_transfer1/auxiliary/guardrail/common/io_utils.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import glob
17
+ from dataclasses import dataclass
18
+
19
+ import imageio
20
+ import numpy as np
21
+
22
+ from cosmos_transfer1.utils import log
23
+
24
+
25
@dataclass
class VideoData:
    """Frames and basic metadata of a video, as returned by ``read_video``."""

    frames: np.ndarray  # Shape: [B, H, W, C]
    # NOTE(review): read_video fills fps and duration with metadata.get(...), so
    # they may be None or non-integer despite the annotations -- confirm.
    fps: int
    duration: int  # in seconds
30
+
31
+
32
def get_video_filepaths(input_dir: str) -> list[str]:
    """Collect the paths of all .mp4, .avi and .mov files under a directory.

    The search is recursive and the combined result is returned sorted.
    """
    found: list[str] = []
    for extension in ("mp4", "avi", "mov"):
        found.extend(glob.glob(f"{input_dir}/**/*.{extension}", recursive=True))
    found.sort()
    log.debug(f"Found {len(found)} videos")
    return found
40
+
41
+
42
def read_video(filepath: str) -> VideoData:
    """Load every frame of a video file together with its fps and duration.

    Args:
        filepath: Path of the video to open with imageio's "ffmpeg" reader.

    Returns:
        A ``VideoData`` holding the stacked frames and the "fps" / "duration"
        metadata entries.

    Raises:
        ValueError: If the file cannot be opened, its metadata cannot be
            read, or its frames cannot be extracted.
    """
    try:
        video_reader = imageio.get_reader(filepath, "ffmpeg")
    except Exception as e:
        raise ValueError(f"Failed to read video file: {filepath}") from e

    # Metadata first; the reader is closed here on failure because the
    # frame-extraction step (which closes it otherwise) is never reached.
    try:
        meta = video_reader.get_meta_data()
        fps, duration = meta.get("fps"), meta.get("duration")
    except Exception as e:
        video_reader.close()
        raise ValueError(f"Failed to extract metadata from video file: {filepath}") from e

    # Frames next; the reader is closed whether or not extraction succeeds.
    try:
        stacked = np.array([image for image in video_reader])
    except Exception as e:
        raise ValueError(f"Failed to extract frames from video file: {filepath}") from e
    finally:
        video_reader.close()

    return VideoData(frames=stacked, fps=fps, duration=duration)
67
+
68
+
69
def save_video(filepath: str, frames: np.ndarray, fps: int) -> None:
    """Save a video file from a sequence of frames.

    Args:
        filepath: Destination path of the video.
        frames: Frames to write, iterated in order.
        fps: Frame rate passed to the writer.

    Raises:
        ValueError: If the writer cannot be created or a frame cannot be written.
    """
    # Start unbound-safe: if get_writer itself fails there is nothing to close,
    # and closing an unassigned name would mask the ValueError raised below.
    writer = None
    try:
        writer = imageio.get_writer(filepath, fps=fps, macro_block_size=1)
        for frame in frames:
            writer.append_data(frame)
    except Exception as e:
        raise ValueError(f"Failed to save video file to {filepath}") from e
    finally:
        if writer is not None:
            writer.close()
cosmos_transfer1/auxiliary/guardrail/common/presets.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ import numpy as np
19
+
20
+ from cosmos_transfer1.auxiliary.guardrail.blocklist.blocklist import Blocklist
21
+ from cosmos_transfer1.auxiliary.guardrail.common.core import GuardrailRunner
22
+ from cosmos_transfer1.auxiliary.guardrail.face_blur_filter.face_blur_filter import RetinaFaceFilter
23
+ from cosmos_transfer1.auxiliary.guardrail.llamaGuard3.llamaGuard3 import LlamaGuard3
24
+ from cosmos_transfer1.auxiliary.guardrail.video_content_safety_filter.video_content_safety_filter import (
25
+ VideoContentSafetyFilter,
26
+ )
27
+ from cosmos_transfer1.utils import log
28
+
29
+
30
def create_text_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
    """Build the text runner: a Blocklist followed by LlamaGuard3, both loaded from ``checkpoint_dir``."""
    text_guardrails = [Blocklist(checkpoint_dir), LlamaGuard3(checkpoint_dir)]
    return GuardrailRunner(safety_models=text_guardrails)
33
+
34
+
35
def create_video_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
    """Build the video runner: a content safety filter plus a RetinaFace postprocessor."""
    safety_filters = [VideoContentSafetyFilter(checkpoint_dir)]
    frame_postprocessors = [RetinaFaceFilter(checkpoint_dir)]
    return GuardrailRunner(safety_models=safety_filters, postprocessors=frame_postprocessors)
41
+
42
+
43
def run_text_guardrail(prompt: str, guardrail_runner: GuardrailRunner) -> bool:
    """Check a text prompt for content safety.

    A blocked prompt is logged at critical level together with the reason.

    Args:
        prompt: The text prompt.
        guardrail_runner: The text guardrail runner.

    Returns:
        bool: Whether the prompt is safe.
    """
    verdict, reason = guardrail_runner.run_safety_check(prompt)
    if verdict:
        return verdict
    log.critical(f"GUARDRAIL BLOCKED: {reason}")
    return verdict
57
+
58
+
59
def run_video_guardrail(frames: np.ndarray, guardrail_runner: GuardrailRunner) -> np.ndarray | None:
    """Check generated frames for content safety, then run the runner's postprocessing.

    Args:
        frames: The frames of the generated video.
        guardrail_runner: The video guardrail runner.

    Returns:
        The postprocessed frames if the safety check passes, otherwise None
        (the block reason is logged at critical level).
    """
    verdict, reason = guardrail_runner.run_safety_check(frames)
    if verdict:
        return guardrail_runner.postprocess(frames)

    log.critical(f"GUARDRAIL BLOCKED: {reason}")
    return None
cosmos_transfer1/auxiliary/guardrail/face_blur_filter/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
cosmos_transfer1/auxiliary/guardrail/face_blur_filter/blur_utils.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import cv2
17
+ import numpy as np
18
+
19
+
20
def pixelate_face(face_img: np.ndarray, blocks: int = 5) -> np.ndarray:
    """
    Pixelate a face region by reducing resolution and then upscaling.

    Args:
        face_img: Face region to pixelate
        blocks: Number of blocks to divide the face into (in each dimension)

    Returns:
        Pixelated face region with the same height and width as ``face_img``.
        An empty region (zero height or width) is returned as an unchanged copy.
    """
    h, w = face_img.shape[:2]
    if h == 0 or w == 0:
        # cv2.resize rejects empty source images; there is nothing to pixelate.
        return face_img.copy()

    # Shrink the image and scale back up to create pixelation effect
    temp = cv2.resize(face_img, (blocks, blocks), interpolation=cv2.INTER_LINEAR)
    return cv2.resize(temp, (w, h), interpolation=cv2.INTER_NEAREST)
cosmos_transfer1/auxiliary/guardrail/face_blur_filter/face_blur_filter.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+
19
+ import numpy as np
20
+ import torch
21
+ from retinaface.data import cfg_re50
22
+ from retinaface.layers.functions.prior_box import PriorBox
23
+ from retinaface.models.retinaface import RetinaFace
24
+ from torch.utils.data import DataLoader, TensorDataset
25
+ from tqdm import tqdm
26
+
27
+ from cosmos_transfer1.auxiliary.guardrail.common.core import GuardrailRunner, PostprocessingGuardrail
28
+ from cosmos_transfer1.auxiliary.guardrail.common.io_utils import get_video_filepaths, read_video, save_video
29
+ from cosmos_transfer1.auxiliary.guardrail.face_blur_filter.blur_utils import pixelate_face
30
+ from cosmos_transfer1.auxiliary.guardrail.face_blur_filter.retinaface_utils import (
31
+ decode_batch,
32
+ filter_detected_boxes,
33
+ load_model,
34
+ )
35
+ from cosmos_transfer1.utils import log, misc
36
+
37
# RetinaFace model constants from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
# They are forwarded to filter_detected_boxes() as top_k, keep_top_k and nms_threshold respectively.
TOP_K = 5_000
KEEP_TOP_K = 750
NMS_THRESHOLD = 0.4
41
+
42
+
43
class RetinaFaceFilter(PostprocessingGuardrail):
    """Postprocessing guardrail that pixelates faces detected by RetinaFace."""

    def __init__(
        self,
        checkpoint_dir: str,
        batch_size: int = 1,
        confidence_threshold: float = 0.7,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
        """
        Initialize the RetinaFace model for face detection and blurring.

        Args:
            checkpoint_dir: Root checkpoint directory; the weights are read from
                ``{checkpoint_dir}/nvidia/Cosmos-Guardrail1/face_blur_filter/Resnet50_Final.pth``
            batch_size: Batch size for RetinaFace inference and processing
            confidence_threshold: Minimum confidence score to consider a face detection
            device: Device the model and the frame tensors are placed on
        """
        self.checkpoint = f"{checkpoint_dir}/nvidia/Cosmos-Guardrail1/face_blur_filter/Resnet50_Final.pth"
        self.cfg = cfg_re50
        self.batch_size = batch_size
        self.confidence_threshold = confidence_threshold
        self.device = device
        self.dtype = torch.float32

        # Disable loading ResNet pretrained weights.
        # Note: self.cfg is the imported cfg_re50 object itself, so this write is shared.
        self.cfg["pretrain"] = False
        self.net = RetinaFace(cfg=self.cfg, phase="test")
        cpu = self.device == "cpu"

        # Load from RetinaFace pretrained checkpoint
        self.net = load_model(self.net, self.checkpoint, cpu)
        self.net.to(self.device, dtype=self.dtype).eval()

    def preprocess_frames(self, frames: np.ndarray) -> torch.Tensor:
        """Preprocess a sequence of frames for face detection.

        Args:
            frames: Input frames

        Returns:
            Preprocessed frames tensor on ``self.device`` with channels first,
            channel order reversed and per-channel means subtracted
        """
        with torch.no_grad():
            frames_tensor = torch.from_numpy(frames).to(self.device, dtype=self.dtype)  # Shape: [T, H, W, C]
            frames_tensor = frames_tensor.permute(0, 3, 1, 2)  # Shape: [T, C, H, W]
            frames_tensor = frames_tensor[:, [2, 1, 0], :, :]  # RGB to BGR to match RetinaFace model input
            means = torch.tensor([104.0, 117.0, 123.0], device=self.device, dtype=self.dtype).view(1, 3, 1, 1)
            frames_tensor = frames_tensor - means  # Subtract mean BGR values for each channel
        return frames_tensor

    def blur_detected_faces(
        self,
        frames: np.ndarray,
        batch_loc: torch.Tensor,
        batch_conf: torch.Tensor,
        prior_data: torch.Tensor,
        scale: torch.Tensor,
        min_size: tuple[int, int] = (20, 20),
    ) -> list[np.ndarray]:
        """Blur detected faces in a batch of frames using RetinaFace predictions.

        Args:
            frames: Input frames
            batch_loc: Batched location predictions
            batch_conf: Batched confidence scores
            prior_data: Prior boxes for the video
            scale: Scale factor for resizing detections
            min_size: Minimum (width, height) of a detected face region in pixels

        Returns:
            Processed frames with pixelated faces. Each returned frame is
            ``frames[i]`` with the face regions overwritten in place.
        """
        with torch.no_grad():
            batch_boxes = decode_batch(batch_loc, prior_data, self.cfg["variance"])
            batch_boxes = batch_boxes * scale

        blurred_frames = []
        for i, boxes in enumerate(batch_boxes):
            boxes = boxes.detach().cpu().numpy()
            scores = batch_conf[i, :, 1].detach().cpu().numpy()

            filtered_boxes = filter_detected_boxes(
                boxes,
                scores,
                confidence_threshold=self.confidence_threshold,
                nms_threshold=NMS_THRESHOLD,
                top_k=TOP_K,
                keep_top_k=KEEP_TOP_K,
            )

            frame = frames[i]
            max_h, max_w = frame.shape[:2]
            for box in filtered_boxes:
                x1, y1, x2, y2 = map(int, box)
                # Ignore bounding boxes smaller than the minimum size
                if x2 - x1 < min_size[0] or y2 - y1 < min_size[1]:
                    continue

                # Clip the box to the frame. A box lying entirely outside the frame
                # must be skipped: a negative upper bound would otherwise be read
                # as an end-relative slice and pixelate an unrelated region.
                left, right = max(x1, 0), min(x2, max_w)
                top, bottom = max(y1, 0), min(y2, max_h)
                if right <= left or bottom <= top:
                    continue

                face_roi = frame[top:bottom, left:right]
                frame[top:bottom, left:right] = pixelate_face(face_roi)
            blurred_frames.append(frame)

        return blurred_frames

    def postprocess(self, frames: np.ndarray) -> np.ndarray:
        """Blur faces in a sequence of frames.

        Args:
            frames: Input frames

        Returns:
            Processed frames with pixelated faces. Empty input is returned
            unchanged without running the detector.
        """
        if len(frames) == 0:
            return frames

        # Create dataset and dataloader
        frames_tensor = self.preprocess_frames(frames)
        dataset = TensorDataset(frames_tensor)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
        processed_batches = []

        prior_data, scale = None, None
        for i, batch in enumerate(dataloader):
            batch = batch[0]
            h, w = batch.shape[-2:]  # Batch shape: [B, C, H, W]

            with torch.no_grad():
                # Generate priors for the video (computed once, from the first batch)
                if prior_data is None:
                    priorbox = PriorBox(self.cfg, image_size=(h, w))
                    priors = priorbox.forward()
                    priors = priors.to(self.device, dtype=self.dtype)
                    prior_data = priors.data

                # Get scale for resizing detections
                if scale is None:
                    scale = torch.Tensor([w, h, w, h])
                    scale = scale.to(self.device, dtype=self.dtype)

                batch_loc, batch_conf, _ = self.net(batch)

            # Blur detected faces in each batch of frames
            start_idx = i * self.batch_size
            end_idx = min(start_idx + self.batch_size, len(frames))
            processed_batches.append(
                self.blur_detected_faces(frames[start_idx:end_idx], batch_loc, batch_conf, prior_data, scale)
            )

        processed_frames = [frame for batch in processed_batches for frame in batch]
        return np.array(processed_frames)
190
+
191
+
192
def parse_args():
    """Parse the command-line arguments of the face blur script.

    Returns:
        Namespace with ``input_dir``, ``output_dir`` and ``checkpoint``. All
        three are required; ``--checkpoint_dir`` is accepted as an alias of
        ``--checkpoint`` and is stored under the same ``checkpoint`` attribute.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True, help="Path containing input videos")
    parser.add_argument("--output_dir", type=str, required=True, help="Path for saving processed videos")
    parser.add_argument(
        "--checkpoint",
        "--checkpoint_dir",
        dest="checkpoint",
        type=str,
        # Without a value the filter would build a path starting with "None/".
        required=True,
        help="Path to the checkpoint root directory containing the RetinaFace weights",
    )
    return parser.parse_args()
202
+
203
+
204
def main(args):
    """Blur faces in every video found in ``args.input_dir``.

    Each processed video is saved under ``args.output_dir`` with the same file
    name and the original frame rate. Logs an error and returns early when the
    input directory contains no videos.

    Args:
        args: Parsed arguments providing ``input_dir``, ``output_dir`` and
            ``checkpoint``.
    """
    filepaths = get_video_filepaths(args.input_dir)
    if not filepaths:
        log.error(f"No video files found in directory: {args.input_dir}")
        return

    # RetinaFaceFilter's constructor parameter is named ``checkpoint_dir``;
    # passing ``checkpoint=`` raises TypeError.
    face_blur = RetinaFaceFilter(checkpoint_dir=args.checkpoint)
    postprocessing_runner = GuardrailRunner(postprocessors=[face_blur])
    os.makedirs(args.output_dir, exist_ok=True)

    for filepath in tqdm(filepaths):
        video_data = read_video(filepath)
        with misc.timer("face blur filter"):
            frames = postprocessing_runner.postprocess(video_data.frames)

        output_path = os.path.join(args.output_dir, os.path.basename(filepath))
        save_video(output_path, frames, video_data.fps)
222
+
223
# Script entry point: parse the CLI arguments and run the face blur pipeline.
if __name__ == "__main__":
    main(parse_args())
cosmos_transfer1/auxiliary/guardrail/face_blur_filter/retinaface_utils.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import numpy as np
17
+ import torch
18
+ from retinaface.utils.nms.py_cpu_nms import py_cpu_nms
19
+
20
+ from cosmos_transfer1.utils import log
21
+
22
+
23
+ # Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
24
def filter_detected_boxes(boxes, scores, confidence_threshold, nms_threshold, top_k, keep_top_k):
    """Select confident, non-overlapping detections.

    Boxes whose score is not above ``confidence_threshold`` are dropped, the
    ``top_k`` highest-scoring survivors go through non-maximum suppression
    with ``nms_threshold``, and at most ``keep_top_k`` boxes are returned
    (as float32, without the score column).
    """
    # Keep detections with confidence above threshold
    confident = np.where(scores > confidence_threshold)[0]
    kept_boxes, kept_scores = boxes[confident], scores[confident]

    # Sort by confidence and keep top K detections
    ranking = kept_scores.argsort()[::-1][:top_k]
    kept_boxes, kept_scores = kept_boxes[ranking], kept_scores[ranking]

    # Run non-maximum-suppression (NMS) to remove overlapping boxes
    detections = np.hstack((kept_boxes, kept_scores[:, np.newaxis])).astype(np.float32, copy=False)
    survivors = py_cpu_nms(detections, nms_threshold)
    return detections[survivors, :][:keep_top_k, :-1]
43
+
44
+
45
+ # Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/utils/box_utils.py to handle batched inputs
46
def decode_batch(loc, priors, variances):
    """Turn batched location offsets into corner-form boxes.

    Args:
        loc (tensor): Batched location predictions for loc layers.
            Shape: [batch_size, num_priors, 4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors, 4]
        variances: (list[float]): Variances of prior boxes.

    Return:
        Decoded batched bounding box predictions as (x1, y1, x2, y2)
        Shape: [batch_size, num_priors, 4]
    """
    batched_priors = priors.unsqueeze(0).expand(loc.size(0), -1, -1)
    prior_centers = batched_priors[:, :, :2]
    prior_sizes = batched_priors[:, :, 2:]

    # Center-size form of the decoded boxes
    centers = prior_centers + loc[:, :, :2] * variances[0] * prior_sizes
    sizes = prior_sizes * torch.exp(loc[:, :, 2:] * variances[1])

    # Convert to corner form
    top_left = centers - sizes / 2
    bottom_right = sizes + top_left
    return torch.cat((top_left, bottom_right), dim=2)
74
+
75
+
76
+ # Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
77
def _check_keys(model, pretrained_state_dict):
    """Verify that the checkpoint shares at least one parameter key with the model.

    The numbers of missing, unused and used keys are logged at debug level.

    Args:
        model: Model whose ``state_dict()`` keys are compared.
        pretrained_state_dict: Mapping of parameter names loaded from the checkpoint.

    Returns:
        True when at least one checkpoint key matches a model key.

    Raises:
        AssertionError: If no checkpoint key matches any model key.
    """
    ckpt_keys = set(pretrained_state_dict)
    model_keys = set(model.state_dict())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    log.debug("Missing keys:{}".format(len(missing_keys)))
    log.debug("Unused checkpoint keys:{}".format(len(unused_pretrained_keys)))
    log.debug("Used keys:{}".format(len(used_pretrained_keys)))
    # Raised explicitly (not via ``assert``) so the check is not stripped under
    # ``python -O``; the exception type callers see is unchanged.
    if not used_pretrained_keys:
        raise AssertionError("load NONE from pretrained checkpoint")
    return True
88
+
89
+
90
+ # Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
91
def _remove_prefix(state_dict, prefix):
    """Return a copy of ``state_dict`` with ``prefix`` stripped from the start of each key.

    Old version of the model is stored with all names of parameters sharing
    common prefix 'module.'. Keys that do not start with ``prefix`` are kept
    unchanged, and an empty prefix leaves every key as is.
    """
    log.debug("Removing prefix '{}'".format(prefix))
    return {key.removeprefix(prefix): value for key, value in state_dict.items()}
99
+
100
+
101
+ # Adapted from https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py
102
def load_model(model, pretrained_path, load_to_cpu):
    """Load pretrained weights from ``pretrained_path`` into ``model``.

    Args:
        model: Model that receives the weights.
        pretrained_path: Path of the checkpoint file passed to ``torch.load``.
        load_to_cpu: When true, storages are left where ``torch.load`` puts
            them; otherwise they are moved to the current CUDA device.

    Returns:
        The same ``model`` object, after a non-strict ``load_state_dict``.
    """
    log.debug("Loading pretrained model from {}".format(pretrained_path))

    if load_to_cpu:

        def place(storage, loc):
            return storage

    else:
        device = torch.cuda.current_device()

        def place(storage, loc):
            return storage.cuda(device)

    checkpoint = torch.load(pretrained_path, map_location=place, weights_only=True)

    # Some checkpoints wrap the weights under a "state_dict" entry.
    weights = checkpoint["state_dict"] if "state_dict" in checkpoint.keys() else checkpoint
    weights = _remove_prefix(weights, "module.")

    _check_keys(model, weights)
    model.load_state_dict(weights, strict=False)
    return model
cosmos_transfer1/auxiliary/guardrail/llamaGuard3/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
cosmos_transfer1/auxiliary/guardrail/llamaGuard3/categories.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
# Hazard category codes and their descriptions (each ends with a period).
# Keys use an upper-case "S" followed by the category number.
UNSAFE_CATEGORIES = {
    "S1": "Violent Crimes.",
    "S2": "Non-Violent Crimes.",
    "S3": "Sex Crimes.",
    "S4": "Child Exploitation.",
    "S5": "Defamation.",
    "S6": "Specialized Advice.",
    "S7": "Privacy.",
    "S8": "Intellectual Property.",
    "S9": "Indiscriminate Weapons.",
    "S10": "Hate.",
    "S11": "Self-Harm.",
    "S12": "Sexual Content.",
    "S13": "Elections.",
    # Was "s14": the lower-case key never matched the upper-case codes used
    # everywhere else in this table.
    "S14": "Code Interpreter Abuse.",
}
cosmos_transfer1/auxiliary/guardrail/llamaGuard3/llamaGuard3.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+
19
+ import torch
20
+ from transformers import AutoModelForCausalLM, AutoTokenizer
21
+
22
+ from cosmos_transfer1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner
23
+ from cosmos_transfer1.auxiliary.guardrail.llamaGuard3.categories import UNSAFE_CATEGORIES
24
+ from cosmos_transfer1.utils import log, misc
25
+
26
# Status labels produced by misc.Color.green / misc.Color.red.
# NOTE(review): nothing in this module's visible code reads these; main() logs plain strings — confirm whether they are still needed.
SAFE = misc.Color.green("SAFE")
UNSAFE = misc.Color.red("UNSAFE")
28
+
29
+
30
class LlamaGuard3(ContentSafetyGuardrail):
    """Text content-safety guardrail backed by the Llama Guard 3 model."""

    def __init__(
        self,
        checkpoint_dir: str,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
        """Load the Llama Guard 3 model and tokenizer.

        Args:
            checkpoint_dir: Root checkpoint directory; the model is read from
                ``{checkpoint_dir}/meta-llama/Llama-Guard-3-8B``.
            device: Device the model and the tokenized prompt are placed on.
        """
        self.checkpoint_dir = checkpoint_dir
        self.device = device
        self.dtype = torch.bfloat16

        model_id = "meta-llama/Llama-Guard-3-8B"
        model_dir = os.path.join(self.checkpoint_dir, model_id)

        self.model = AutoModelForCausalLM.from_pretrained(model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

        self.model.to(self.device, dtype=self.dtype).eval()

    def get_llamaGuard3_block_message(self, moderation_output: str) -> str:
        """Extract the blocked category from the Llama Guard 3 model output.

        Every line after the first is read as a comma-separated list of
        category codes. Recognized codes are appended to the message by name;
        unrecognized ones are logged as warnings.

        Args:
            moderation_output: Decoded text generated by the model.

        Returns:
            The block message, with the violated category names when any were found.
        """
        block_msg = "Prompt blocked by Llama Guard 3."
        try:
            lines = moderation_output.splitlines()
            categories_detected = []
            for line in lines[1:]:
                line_stripped = line.split("<|eot_id|>")[0].strip()
                for category in line_stripped.split(","):
                    category = category.strip()
                    if not category:
                        # Blank lines and trailing commas yield empty tokens; not a category.
                        continue
                    if category not in UNSAFE_CATEGORIES:
                        log.warning(f"Unrecognized category from moderation output: {category}")
                    else:
                        categories_detected.append(category)
            if categories_detected:
                # Descriptions end with a period; drop it before joining.
                blocked_categories = ", ".join(UNSAFE_CATEGORIES[category][:-1] for category in categories_detected)
                block_msg = f"{block_msg} Violations: {blocked_categories}."
        except Exception as e:
            log.warning(f"Unable to extract blocked category from Llama Guard 3 output: {e}")
        return block_msg

    def filter_llamaGuard3_output(self, prompt: str) -> tuple[bool, str]:
        """Filter the Llama Guard 3 model output and return the safety status and message.

        Args:
            prompt: The user prompt to moderate.

        Returns:
            ``(True, "")`` when the generated text does not contain "unsafe",
            otherwise ``(False, block_message)``.
        """
        conversation = [{"role": "user", "content": prompt}]
        # Place the input on the configured device (was hard-coded to "cuda",
        # which fails on hosts where the model was loaded on CPU).
        input_ids = self.tokenizer.apply_chat_template(
            conversation, categories=UNSAFE_CATEGORIES, return_tensors="pt"
        ).to(self.device)
        prompt_len = input_ids.shape[1]
        output = self.model.generate(
            input_ids=input_ids,
            max_new_tokens=100,
            return_dict_in_generate=True,
            pad_token_id=0,
        )
        # Keep only the newly generated tokens, not the echoed prompt.
        generated_tokens = output.sequences[:, prompt_len:]
        moderation_output = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=False).strip()

        if "unsafe" in moderation_output.lower():
            block_msg = self.get_llamaGuard3_block_message(moderation_output)
            return False, block_msg
        return True, ""

    def is_safe(self, prompt: str) -> tuple[bool, str]:
        """Check if the input prompt is safe according to the Llama Guard 3 model.

        Returns:
            The result of ``filter_llamaGuard3_output``. If that call raises,
            the error is logged and the prompt is reported as safe.
        """
        try:
            return self.filter_llamaGuard3_output(prompt)
        except Exception as e:
            log.error(f"Unexpected error occurred when running Llama Guard 3 guardrail: {e}")
            # NOTE(review): this fails open (an internal error lets the prompt
            # through); kept as-is for compatibility — confirm this is intended.
            return True, "Unexpected error occurred when running Llama Guard 3 guardrail."
98
+
99
+
100
def parse_args():
    """Parse the command-line arguments of the Llama Guard 3 script.

    Returns:
        Namespace with ``prompt`` and ``checkpoint_dir``; both are required.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        # Without a value the guardrail would call os.path.join(None, ...) and fail.
        required=True,
        help="Path to the Llama Guard 3 checkpoint folder",
    )
    return parser.parse_args()
109
+
110
+
111
def main(args):
    """Run the Llama Guard 3 safety check on ``args.prompt`` and log the verdict.

    Args:
        args: Parsed arguments providing ``prompt`` and ``checkpoint_dir``.
    """
    guardrail = LlamaGuard3(checkpoint_dir=args.checkpoint_dir)
    runner = GuardrailRunner(safety_models=[guardrail])
    with misc.timer("Llama Guard 3 safety check"):
        safety, message = runner.run_safety_check(args.prompt)

    verdict = "SAFE" if safety else "UNSAFE"
    log.info(f"Input is: {verdict}")
    # The message is only logged for blocked prompts.
    if not safety:
        log.info(f"Message: {message}")
118
+
119
+
120
# Script entry point: parse the CLI arguments and run the prompt safety check.
if __name__ == "__main__":
    main(parse_args())
cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/model.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import attrs
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from cosmos_transfer1.utils.ddp_config import make_freezable
21
+
22
+
23
@make_freezable
@attrs.define(slots=False)
class ModelConfig:
    """Configuration of the video safety classifier."""

    # Size of the classifier's input feature vector.
    input_size: int = 1152
    # Number of classes the classifier outputs logits for.
    num_classes: int = 7
28
+
29
+
30
class SafetyClassifier(nn.Module):
    """MLP head: two Linear/BatchNorm/ReLU stages (512, 256) followed by a Linear output layer."""

    def __init__(self, input_size: int = 1024, num_classes: int = 2):
        """Build the layer stack.

        Args:
            input_size: Size of the input feature vector.
            num_classes: Number of output logits.
        """
        super().__init__()
        self.input_size = input_size
        self.num_classes = num_classes

        stages = []
        in_features = self.input_size
        for width in (512, 256):
            stages += [nn.Linear(in_features, width), nn.BatchNorm1d(width), nn.ReLU()]
            in_features = width
        # Note: No activation function here; CrossEntropyLoss expects raw logits
        stages.append(nn.Linear(in_features, self.num_classes))
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Return the raw logits for ``x``."""
        return self.layers(x)
48
+
49
+
50
class VideoSafetyModel(nn.Module):
    """Wraps a SafetyClassifier built from a ModelConfig."""

    def __init__(self, config: ModelConfig) -> None:
        """Create the classifier network.

        Args:
            config: Provides ``input_size`` and ``num_classes`` for the network.
        """
        super().__init__()
        self.config = config
        self.num_classes = config.num_classes
        self.network = SafetyClassifier(input_size=config.input_size, num_classes=self.num_classes)

    @torch.inference_mode()
    def forward(self, data_batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Classify a batch of features.

        Args:
            data_batch: Dictionary whose ``"data"`` entry holds the input tensor.

        Returns:
            Dictionary with the classifier output under ``"logits"``.
        """
        # Move the input to wherever the network's weights live. The previous
        # unconditional ``.cuda()`` failed on CPU-only hosts.
        device = next(self.network.parameters()).device
        logits = self.network(data_batch["data"].to(device))
        return {"logits": logits}
cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/video_content_safety_filter.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import json
18
+ import os
19
+ from typing import Iterable, Tuple, Union
20
+
21
+ import torch
22
+ from PIL import Image
23
+
24
+ from cosmos_transfer1.auxiliary.guardrail.common.core import ContentSafetyGuardrail, GuardrailRunner
25
+ from cosmos_transfer1.auxiliary.guardrail.common.io_utils import get_video_filepaths, read_video
26
+ from cosmos_transfer1.auxiliary.guardrail.video_content_safety_filter.model import ModelConfig, VideoSafetyModel
27
+ from cosmos_transfer1.auxiliary.guardrail.video_content_safety_filter.vision_encoder import SigLIPEncoder
28
+ from cosmos_transfer1.utils import log, misc
29
+
30
# Define the class index to class name mapping for multi-class classification
# NOTE(review): index 2 has no entry although the classifier in this file is built with
# num_classes=7; lookups use .get(idx, "Safe"), so class 2 is reported as "Safe" — confirm this is intentional.
CLASS_IDX_TO_NAME = {
    0: "Safe",
    1: "Sexual_Content",
    3: "Drugs",
    4: "Child_Abuse",
    5: "Hate_and_Harassment",
    6: "Self-Harm",
}
39
+
40
+
41
class VideoContentSafetyFilter(ContentSafetyGuardrail):
    """Content-safety guardrail that classifies video frames using SigLIP embeddings."""

    def __init__(
        self,
        checkpoint_dir: str,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
        """Load the SigLIP encoder and the safety classifier.

        Args:
            checkpoint_dir: Root checkpoint directory; files are read from its
                ``nvidia/Cosmos-Guardrail1/video_content_safety_filter`` subfolder.
            device: Device the encoder and the classifier are placed on.
        """
        self.checkpoint_dir = os.path.join(checkpoint_dir, "nvidia/Cosmos-Guardrail1/video_content_safety_filter")
        self.device = device
        self.dtype = torch.float32

        # Initialize the SigLIP encoder
        self.encoder = SigLIPEncoder(checkpoint_dir=self.checkpoint_dir, device=device, dtype=self.dtype)

        # Use ModelConfig directly for inference configuration
        model_config = ModelConfig(input_size=1152, num_classes=7)

        # Load the multi-class classifier
        self.model = VideoSafetyModel(model_config)
        safety_filter_local_path = os.path.join(self.checkpoint_dir, "safety_filter.pt")
        checkpoint = torch.load(safety_filter_local_path, map_location=torch.device("cpu"), weights_only=True)
        self.model.load_state_dict(checkpoint["model"])
        self.model.to(self.device, dtype=self.dtype).eval()

    @torch.inference_mode()
    def __infer(self, pil_image: Image.Image) -> int:
        """Infer the class of the image and return the index of the highest-scoring class."""
        image_embs = self.encoder.encode_image(pil_image)
        logits = self.model.network(image_embs)
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        predicted_class = torch.argmax(probabilities, dim=-1).item()
        return predicted_class

    def is_safe_file(self, filepath: str) -> bool:
        """Check if the video file is safe.

        Frames are sampled at roughly 2 per second; the video is unsafe as soon
        as one sampled frame is classified as anything other than "Safe".
        Frames on which the classifier fails are logged and skipped.

        Args:
            filepath: Path of the video file to read.

        Returns:
            True if no sampled frame was classified as unsafe.
        """
        video_data = read_video(filepath)

        # Sample frames at 2 FPS
        sample_rate = 2  # frames per second
        # Clamp to 1: for videos slower than the sample rate the integer
        # division yields 0, and range() rejects a zero step.
        frame_interval = max(1, int(video_data.fps / sample_rate))
        frame_numbers = list(range(0, int(video_data.fps * video_data.duration), frame_interval))

        is_safe = True
        frame_scores = []

        for frame_number in frame_numbers:
            try:
                frame = video_data.frames[frame_number]
                pil_image = Image.fromarray(frame)
                predicted_class = self.__infer(pil_image)
                class_name = CLASS_IDX_TO_NAME.get(predicted_class, "Safe")
                frame_scores.append({"frame_number": frame_number, "class": class_name})

                # If any frame is not "Safe", mark the video as unsafe
                if class_name != "Safe":
                    is_safe = False
                    break

            except Exception as e:
                log.warning(f"Warning: Failed to run safety classifier on frame_number {frame_number}. Exception: {e}")
                continue

        # Prepare data for JSON (kept in its own name so ``video_data`` is not rebound)
        report = {
            "filepath": filepath,
            "is_safe": is_safe,
            "video_length": video_data.duration,
            "fps": video_data.fps,
            "frame_scores": frame_scores,
        }

        log.info(f"Video {filepath} is {'SAFE' if is_safe else 'UNSAFE'}.")
        log.debug(f"Video data: {json.dumps(report, indent=4)}")
        return is_safe

    def is_safe_frames(self, frames: Iterable) -> bool:
        """Check if the generated video frames are safe.

        Every frame is classified; the frames are safe when at least 95% of
        them are classified as "Safe". Frames on which the classifier fails
        still count toward the total, and an empty input is reported as unsafe.

        Args:
            frames: Iterable of frames accepted by ``Image.fromarray``.

        Returns:
            True if the ratio of safe frames is at least 0.95.
        """
        frame_scores = []
        total_frames = 0
        safe_frames = 0

        for frame_number, frame in enumerate(frames):
            try:
                total_frames += 1
                pil_image = Image.fromarray(frame)
                predicted_class = self.__infer(pil_image)
                class_name = CLASS_IDX_TO_NAME.get(predicted_class, "Safe")
                frame_scores.append({"frame_number": frame_number, "class": class_name})

                if class_name == "Safe":
                    safe_frames += 1

            except Exception as e:
                log.warning(f"Warning: Failed to run safety classifier on frame_number {frame_number}. Exception: {e}")
                continue

        # Decide if the video is safe based on the ratio of safe frames
        is_safe = False
        if total_frames > 0:
            is_safe = (safe_frames / total_frames) >= 0.95

        report = {
            "is_safe": is_safe,
            "frame_scores": frame_scores,
        }

        log.debug(f"Frames data: {json.dumps(report, indent=4)}")
        return is_safe

    def is_safe(self, input: Union[str, Iterable]) -> Tuple[bool, str]:
        """Check a video file path or an iterable of frames.

        Args:
            input: A string is treated as a video file path; anything else is
                treated as an iterable of frames.

        Returns:
            The safety verdict and a short message describing it.
        """
        if isinstance(input, str):
            is_safe = self.is_safe_file(input)
            return is_safe, "safe video detected" if is_safe else "unsafe video detected"
        else:
            is_safe = self.is_safe_frames(input)
            return is_safe, "safe frames detected" if is_safe else "unsafe frames detected"
156
+
157
+
158
def parse_args():
    """Parse the command-line arguments of the video content safety filter script.

    Returns:
        Namespace with ``input_dir`` and ``checkpoint_dir``; both are required.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True, help="Path containing input videos")
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        # Without a value the filter would call os.path.join(None, ...) and fail.
        required=True,
        help="Path to the Video Content Safety Filter checkpoint folder",
    )
    return parser.parse_args()
167
+
168
+
169
def main(args):
    """Run the video content safety filter on every video in ``args.input_dir``.

    Logs an error and returns early when the directory contains no videos.

    Args:
        args: Parsed arguments providing ``input_dir`` and ``checkpoint_dir``.
    """
    video_paths = get_video_filepaths(args.input_dir)
    if not video_paths:
        log.error(f"No video files found in directory: {args.input_dir}")
        return

    safety_filter = VideoContentSafetyFilter(checkpoint_dir=args.checkpoint_dir)
    runner = GuardrailRunner(safety_models=[safety_filter], generic_safe_msg="Video is safe")

    # The verdict is not used here; each check is only timed.
    for video_path in video_paths:
        with misc.timer("video content safety filter"):
            runner.run_safety_check(video_path)
181
+
182
+
183
# Script entry point: parse the CLI arguments and run the video safety filter.
if __name__ == "__main__":
    main(parse_args())
cosmos_transfer1/auxiliary/guardrail/video_content_safety_filter/vision_encoder.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ import torch
19
+ from PIL import Image
20
+ from transformers import SiglipModel, SiglipProcessor
21
+
22
+
23
class SigLIPEncoder(torch.nn.Module):
    """SigLIP image encoder that produces L2-normalized image feature vectors."""

    def __init__(
        self,
        checkpoint_dir: str,
        model_name: str = "google/siglip-so400m-patch14-384",
        device="cuda" if torch.cuda.is_available() else "cpu",
        dtype=torch.float32,
    ) -> None:
        """Load the SigLIP model and processor, then move the model to ``device``/``dtype``.

        Args:
            checkpoint_dir: Directory passed as ``cache_dir`` to ``from_pretrained``.
            model_name: Name of the pretrained SigLIP checkpoint to load.
            device: Device the model and the encoder inputs are moved to.
            dtype: Dtype the model and the encoder inputs are cast to.
        """
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        self.device = device
        self.dtype = dtype

        pretrained_kwargs = {"cache_dir": self.checkpoint_dir}
        self.model = SiglipModel.from_pretrained(model_name, **pretrained_kwargs)
        self.processor = SiglipProcessor.from_pretrained(model_name, **pretrained_kwargs)
        self.model.to(self.device, dtype=self.dtype).eval()

    @torch.inference_mode()
    def encode_image(self, input_img: Image.Image) -> torch.Tensor:
        """Encode an image into a feature vector normalized along the last dimension."""
        # inference_mode already disables gradient tracking, so no extra no_grad scope is needed.
        batch = self.processor(images=input_img, return_tensors="pt")
        batch = batch.to(self.device, dtype=self.dtype)
        features = self.model.get_image_features(**batch)
        features /= features.norm(dim=-1, keepdim=True)
        return features
cosmos_transfer1/auxiliary/human_keypoint/human_keypoint.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ import cv2
19
+ import numpy as np
20
+ from rtmlib import Wholebody
21
+
22
+ from cosmos_transfer1.diffusion.datasets.augmentors.human_keypoint_utils import (
23
+ coco_wholebody_133_skeleton,
24
+ openpose134_skeleton,
25
+ )
26
+ from cosmos_transfer1.utils import log
27
+
28
+
29
class HumanKeypointModel:
    """Whole-body pose estimator that renders detected skeletons into a keypoint video.

    Wraps an ``rtmlib.Wholebody`` model (ONNX Runtime backend, "performance" mode)
    and draws the detected keypoints of every frame on a black canvas.
    """

    def __init__(self, to_openpose=True, conf_thres=0.6, device="cuda"):
        """Create the pose model.

        Args:
            to_openpose: Forwarded to ``Wholebody``; also selects the skeleton
                topology used when drawing (OpenPose-134 vs. COCO-WholeBody-133).
            conf_thres: Minimum keypoint score for a keypoint/link to be drawn.
            device: Device string forwarded to ``Wholebody`` (defaults to ``"cuda"``,
                the value that used to be hard-coded).
        """
        self.model = Wholebody(
            to_openpose=to_openpose,
            mode="performance",
            backend="onnxruntime",
            device=device,
        )
        self.to_openpose = to_openpose
        self.conf_thres = conf_thres

    def __call__(self, input_video: str, output_video: str = "keypoint.mp4") -> str:
        """Generate the human body keypoint plot for the keypointControlNet video2world model.

        Args:
            input_video: Path to the input mp4 video.
            output_video: Path of the mp4 keypoint video to write. It has the same
                frame size and frame rate as the input video.

        Returns:
            The path of the written keypoint video (``output_video``).

        Raises:
            FileNotFoundError: If ``input_video`` does not exist.
        """
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if not os.path.exists(input_video):
            raise FileNotFoundError(f"Input video not found: {input_video}")

        log.info(f"Processing video: {input_video} to generate keypoint video: {output_video}")

        cap = cv2.VideoCapture(input_video)
        skeleton_writer = None
        try:
            # Keep the frame rate as a float: truncating e.g. 29.97 to 29 changes the
            # duration of the output relative to the input.
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            frame_size = (frame_width, frame_height)

            # vid writer
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            skeleton_writer = cv2.VideoWriter(output_video, fourcc, fps, frame_size)

            log.info(f"frame width: {frame_width}, frame height: {frame_height}, fps: {fps}")
            log.info("start pose estimation for frames..")

            # Process each frame
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                # Create a black background frame
                black_frame = np.zeros_like(frame)

                # Run pose estimation
                keypoints, scores = self.model(frame)

                if keypoints is not None and len(keypoints) > 0:
                    skeleton_frame = self.plot_person_kpts(
                        black_frame,
                        keypoints,
                        scores,
                        kpt_thr=self.conf_thres,
                        # Draw with the topology matching the model's output format.
                        openpose_format=self.to_openpose,
                        line_width=4,
                    )  # (h, w, 3)
                else:
                    skeleton_frame = black_frame

                # Reverse the channel order before writing (presumably the skeleton colors
                # are RGB while the OpenCV writer expects BGR -- TODO confirm).
                skeleton_writer.write(np.ascontiguousarray(skeleton_frame[:, :, ::-1]))
        finally:
            # Release the capture/writer even if pose estimation or drawing fails.
            cap.release()
            if skeleton_writer is not None:
                skeleton_writer.release()

        return output_video

    def draw_skeleton(
        self,
        img: np.ndarray,
        keypoints: np.ndarray,
        scores: np.ndarray,
        kpt_thr: float = 0.6,
        openpose_format: bool = True,
        radius: int = 2,
        line_width: int = 4,
    ):
        """Draw the keypoints and links of one person onto ``img``.

        Args:
            img: Image to draw on.
            keypoints: 2-D array of keypoint coordinates for one person; the first two
                values of each row are used as (x, y).
            scores: Per-keypoint confidence scores.
            kpt_thr: Keypoints scoring below this value are not drawn; a link is drawn
                only when both of its endpoints pass the threshold.
            openpose_format: Use the OpenPose-134 topology when True, otherwise the
                COCO-WholeBody-133 topology.
            radius: Radius of the keypoint circles.
            line_width: Thickness of the link lines.

        Returns:
            The image with the skeleton drawn on it.
        """
        skeleton_topology = openpose134_skeleton if openpose_format else coco_wholebody_133_skeleton
        assert len(keypoints.shape) == 2
        keypoint_info = skeleton_topology["keypoint_info"]
        skeleton_info = skeleton_topology["skeleton_info"]

        vis_kpt = [s >= kpt_thr for s in scores]
        # Maps keypoint name -> keypoint id so links (given by name) can be resolved.
        link_dict = {}
        for i, kpt_info in keypoint_info.items():
            kpt_color = tuple(kpt_info["color"])
            link_dict[kpt_info["name"]] = kpt_info["id"]

            kpt = keypoints[i]

            if vis_kpt[i]:
                img = cv2.circle(img, (int(kpt[0]), int(kpt[1])), int(radius), kpt_color, -1)

        for ske_info in skeleton_info.values():
            link = ske_info["link"]
            pt0, pt1 = link_dict[link[0]], link_dict[link[1]]

            if vis_kpt[pt0] and vis_kpt[pt1]:
                link_color = ske_info["color"]
                kpt0 = keypoints[pt0]
                kpt1 = keypoints[pt1]

                img = cv2.line(
                    img, (int(kpt0[0]), int(kpt0[1])), (int(kpt1[0]), int(kpt1[1])), link_color, thickness=line_width
                )

        return img

    def plot_person_kpts(
        self,
        pose_vis_img: np.ndarray,
        keypoints: np.ndarray,
        scores: np.ndarray,
        kpt_thr: float = 0.6,
        openpose_format: bool = True,
        line_width: int = 4,
    ) -> np.ndarray:
        """Draw the skeleton of every detected person onto the pose image.

        ``keypoints`` and ``scores`` are iterated together, one entry per person.
        A person whose drawing raises ``ValueError`` is logged and skipped
        (presumably caused by non-finite coordinates -- TODO confirm).

        Returns:
            The pose image with all drawable skeletons rendered.
        """
        for kpts, ss in zip(keypoints, scores):
            try:
                pose_vis_img = self.draw_skeleton(
                    pose_vis_img, kpts, ss, kpt_thr=kpt_thr, openpose_format=openpose_format, line_width=line_width
                )
            except ValueError as e:
                log.error(f"Error in draw_skeleton func, {e}")

        return pose_vis_img
cosmos_transfer1/auxiliary/robot_augmentation/README.md ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Robot Data Augmentation with Cosmos-Transfer1
2
+
3
+ This pipeline provides a two-step process to augment robotic videos using **Cosmos-Transfer1-7B**. It leverages **spatial-temporal control** to modify backgrounds while preserving the shape and/or appearance of the robot foreground.
4
+
5
+ ## Overview of Settings
6
+
7
+ We propose two augmentation settings:
8
+
9
+ ### Setting 1 (fg_vis_edge_bg_seg): Preserve Shape and Appearance of the Robot (foreground)
10
+ - **Foreground Controls**: `Edge`, `Vis`
11
+ - **Background Controls**: `Segmentation`
12
+ - **Weights**:
13
+ - `w_edge(FG) = 1`
14
+ - `w_vis(FG) = 1`
15
+ - `w_seg(BG) = 1`
16
+ - All other weights = 0
17
+
18
+ ### Setting 2 (fg_edge_bg_seg): Preserve Only Shape of the Robot (foreground)
19
+ - **Foreground Controls**: `Edge`
20
+ - **Background Controls**: `Segmentation`
21
+ - **Weights**:
22
+ - `w_edge(FG) = 1`
23
+ - `w_seg(BG) = 1`
24
+ - All other weights = 0
25
+
26
+ ## Step-by-Step Instructions
27
+
28
+ ### Step 1: Generate Spatial-Temporal Weights
29
+
30
+ This script extracts foreground (robot) and background information from semantic segmentation data. It processes per-frame segmentation masks and color-to-class mappings to generate spatial-temporal weight matrices for each control modality based on the selected setting.
31
+
32
+ #### Input Requirements:
33
+ - A `segmentation` folder containing per-frame segmentation masks in PNG format
34
+ - A `segmentation_label` folder containing color-to-class mapping JSON files for each frame, for example:
35
+ ```json
36
+ {
37
+ "(29, 0, 0, 255)": {
38
+ "class": "gripper0_right_r_palm_vis"
39
+ },
40
+ "(31, 0, 0, 255)": {
41
+ "class": "gripper0_right_R_thumb_proximal_base_link_vis"
42
+ },
43
+ "(33, 0, 0, 255)": {
44
+ "class": "gripper0_right_R_thumb_proximal_link_vis"
45
+ }
46
+ }
47
+ ```
48
+ - An input video file
49
+
50
+ Here is an example input format:
51
+ [Example input directory](https://github.com/nvidia-cosmos/cosmos-transfer1/tree/main/assets/robot_augmentation_example/example1)
52
+
53
+ #### Usage
54
+
55
+ ```bash
56
+ PYTHONPATH=$(pwd) python cosmos_transfer1/auxiliary/robot_augmentation/spatial_temporal_weight.py \
57
+ --setting setting1 \
58
+ --robot-keywords world_robot gripper robot \
59
+ --input-dir assets/robot_augmentation_example \
60
+ --output-dir outputs/robot_augmentation_example
61
+ ```
62
+
63
+ #### Parameters:
64
+
65
+ * `--setting`: Weight setting to use (choices: 'setting1', 'setting2', default: 'setting1')
66
+ * setting1: Emphasizes robot in visual and edge features (vis: 1.0 foreground, edge: 1.0 foreground, seg: 1.0 background)
67
+ * setting2: Emphasizes robot only in edge features (edge: 1.0 foreground, seg: 1.0 background)
68
+
69
+ * `--input-dir`: Input directory containing example folders
70
+ * Default: 'assets/robot_augmentation_example'
71
+
72
+ * `--output-dir`: Output directory for weight matrices
73
+ * Default: 'outputs/robot_augmentation_example'
74
+
75
+ * `--robot-keywords`: Keywords used to identify robot classes
76
+ * Default: ["world_robot", "gripper", "robot"]
77
+ * Any semantic class containing these keywords will be treated as robot foreground
78
+
79
+ ### Step 2: Run Cosmos-Transfer1 Inference
80
+
81
+ Use the generated spatial-temporal weight matrices to perform video augmentation with the proper controls.
82
+
83
+ ```bash
84
+ export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:=0}"
85
+ export CHECKPOINT_DIR="${CHECKPOINT_DIR:=./checkpoints}"
86
+ export NUM_GPU="${NUM_GPU:=1}"
87
+
88
+ PYTHONPATH=$(pwd) torchrun --nproc_per_node=$NUM_GPU --nnodes=1 --node_rank=0 \
89
+ cosmos_transfer1/diffusion/inference/transfer.py \
90
+ --checkpoint_dir $CHECKPOINT_DIR \
91
+ --video_save_folder outputs/robot_example_spatial_temporal_setting1 \
92
+ --controlnet_specs assets/robot_augmentation_example/example1/inference_cosmos_transfer1_robot_spatiotemporal_weights.json \
93
+ --offload_text_encoder_model \
94
+ --offload_guardrail_models \
95
+ --num_gpus $NUM_GPU
96
+ ```
97
+
98
+ - Augmented videos are saved in `outputs/robot_example_spatial_temporal_setting1/`
99
+
100
+ ## Input and Output Example
101
+
102
+ Input video:
103
+
104
+ <video src="https://github.com/user-attachments/assets/9c2df99d-7d0c-4dcf-af87-4ec9f65328ed">
105
+ Your browser does not support the video tag.
106
+ </video>
107
+
108
+ You can run the pipeline multiple times with different prompts (e.g., `assets/robot_augmentation_example/example1/example1_prompts.json`) to obtain different augmentation results:
109
+
110
+ <video src="https://github.com/user-attachments/assets/6dee15f5-9d8b-469a-a92a-3419cb466d44">
111
+ Your browser does not support the video tag.
112
+ </video>
cosmos_transfer1/auxiliary/robot_augmentation/spatial_temporal_weight.py ADDED
@@ -0,0 +1,577 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # This script processes segmentation results for each video frame saved as JSON files and generates a spatial-temporal weight matrix saved as a .pt file.
17
+ # The input JSON files contain segmentation information for each frame, and the output .pt file represents the spatial-temporal weight matrix for the video.
18
+
19
+ import argparse
20
+ import glob
21
+ import json
22
+ import logging
23
+ import os
24
+ import re
25
+ from collections import defaultdict
26
+
27
+ import cv2
28
+ import numpy as np
29
+ import torch
30
+ from tqdm import tqdm
31
+
32
+ # Configure logging
33
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ # Class to manage different weight settings
38
class WeightSettings:
    """Class to manage different weight settings for the features"""

    # Name of the setting returned when an unknown name is requested.
    DEFAULT_SETTING = "fg_vis_edge_bg_seg"

    # Single source of truth for all settings, so get_settings() and
    # list_settings() cannot drift apart.
    _SETTINGS = {
        # Robot (foreground) drives vis and edge; background drives seg; depth unused.
        "fg_vis_edge_bg_seg": {
            "depth": {"foreground": 0.0, "background": 0.0},
            "vis": {"foreground": 1.0, "background": 0.0},
            "edge": {"foreground": 1.0, "background": 0.0},
            "seg": {"foreground": 0.0, "background": 1.0},
        },
        # Robot (foreground) drives edge only; background drives seg.
        "fg_edge_bg_seg": {
            "depth": {"foreground": 0.0, "background": 0.0},
            "vis": {"foreground": 0.0, "background": 0.0},
            "edge": {"foreground": 1.0, "background": 0.0},
            "seg": {"foreground": 0.0, "background": 1.0},
        },
    }

    @staticmethod
    def get_settings(setting_name):
        """Get weight settings by name

        Args:
            setting_name (str): Name of the setting

        Returns:
            dict: Dictionary with weights for each feature. A fresh copy is
                returned on every call, so callers may modify it freely. When
                ``setting_name`` is unknown, a warning is logged and the default
                setting is returned.
        """
        table = WeightSettings._SETTINGS
        if setting_name not in table:
            logger.warning(f"Setting '{setting_name}' not found. Using default.")
            setting_name = WeightSettings.DEFAULT_SETTING

        return {feature: dict(weights) for feature, weights in table[setting_name].items()}

    @staticmethod
    def list_settings():
        """List all available settings

        Returns:
            list: List of setting names
        """
        return list(WeightSettings._SETTINGS)
81
+
82
+
83
def get_video_info(video_path):
    """Return ``(width, height, frame_count, fps)`` for the video at ``video_path``.

    Raises:
        ValueError: If the video file cannot be opened.
    """
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise ValueError(f"Could not open video file: {video_path}")

    # Width, height and frame count are reported as floats by OpenCV; cast them to int.
    int_props = (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FRAME_COUNT)
    width, height, frame_count = (int(capture.get(prop)) for prop in int_props)
    fps = capture.get(cv2.CAP_PROP_FPS)

    capture.release()
    return width, height, frame_count, fps
96
+
97
+
98
def parse_color_key(color_key):
    """Parse a color key string into an RGB tuple

    Args:
        color_key (str): Color key string in the format "(r,g,b,a)" or similar

    Returns:
        tuple: RGB tuple (r, g, b)

    Raises:
        ValueError: If fewer than three integers can be found in ``color_key``.
    """
    # Pull out every run of digits so different separators/brackets are tolerated.
    channels = re.findall(r"\d+", color_key)
    if len(channels) < 3:
        raise ValueError(f"Invalid color key format: {color_key}")

    # Only the first three numbers are used; any further value (e.g. alpha) is ignored.
    red, green, blue = (int(channel) for channel in channels[:3])
    return (red, green, blue)
114
+
115
+
116
def save_visualization(mask, frame_num, feature_name, viz_dir):
    """Write a binary mask to ``viz_dir`` as a PNG image.

    Args:
        mask (numpy.ndarray): The mask (values 0 or 255)
        frame_num (int): The frame number, zero-padded to six digits in the file name
        feature_name (str): The name of the feature (depth, vis, edge, seg)
        viz_dir (str): Directory to save visualizations
    """
    filename = f"{feature_name}_frame_{frame_num:06d}.png"
    output_path = os.path.join(viz_dir, filename)
    # The mask is written as-is, without any colorization.
    cv2.imwrite(output_path, mask)
    logger.info(f"Saved binary visualization to {output_path}")
129
+
130
+
131
def _frame_number_from_path(path):
    """Return the frame index encoded as the last ``_``-separated token of the file name."""
    return int(os.path.basename(path).split("_")[-1].split(".")[0])


def _classify_robot_pixels(frame_bgr, color_array, is_robot_array, height, width):
    """Return a (height, width) boolean mask marking pixels whose nearest mapped color is a robot class."""
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

    # Flatten to (height*width, 3) and compare every pixel with every known color.
    pixels = frame_rgb.reshape(-1, 3)
    # Squared distances are enough: argmin is unchanged by the monotonic sqrt.
    squared_distances = np.sum((pixels[:, np.newaxis, :] - color_array[np.newaxis, :, :]) ** 2, axis=2)
    nearest_color_indices = np.argmin(squared_distances, axis=1)

    return is_robot_array[nearest_color_indices].reshape(height, width)


def _store_frame_weights(
    slot, frame_num, frame_bgr, color_array, is_robot_array, weights_dict, weight_tensors, viz_dir, height, width
):
    """Compute the per-feature weight masks of one frame, store them at ``slot`` and save the visualization."""
    robot_mask = _classify_robot_pixels(frame_bgr, color_array, is_robot_array, height, width)

    total_pixels = height * width
    robot_pixel_count = np.sum(robot_mask)
    matched_pixel_count = total_pixels  # Every pixel is assigned to its nearest color

    # Log statistics
    robot_percentage = (robot_pixel_count / total_pixels) * 100
    matched_percentage = (matched_pixel_count / total_pixels) * 100
    logger.info(f"Frame {frame_num}: {robot_pixel_count} robot pixels ({robot_percentage:.2f}%)")
    logger.info(f"Frame {frame_num}: {matched_pixel_count} matched pixels ({matched_percentage:.2f}%)")

    # Save the robot/non-robot mask (255 = robot) for visual inspection
    visualization_mask = np.zeros((height, width), dtype=np.uint8)
    visualization_mask[robot_mask] = 255
    save_visualization(visualization_mask, frame_num, "segmentation", viz_dir)

    # Foreground weight on robot pixels, background weight everywhere else
    for feature, tensor in weight_tensors.items():
        feature_weights = weights_dict[feature]
        feature_mask = np.where(robot_mask, feature_weights["foreground"], feature_weights["background"])
        tensor[slot] = torch.from_numpy(feature_mask)


def process_segmentation_files(
    segmentation_dir,
    output_dir,
    viz_dir,
    video_path=None,
    weights_dict=None,
    setting_name="fg_vis_edge_bg_seg",
    robot_keywords=None,
):
    """Process all segmentation JSON files and create weight matrices

    Segmentation frames are read from the ``segmentation`` folder next to
    ``segmentation_dir`` when it contains PNG files; otherwise they are read
    from ``video_path``. Each pixel is assigned to the nearest color of the
    unified color-to-class mapping and weighted as foreground (robot) or
    background. One float16 tensor of shape (num_frames, height, width) is
    saved per feature as ``<feature>_weights.pt``.

    Args:
        segmentation_dir (str): Directory containing segmentation JSON files
        output_dir (str): Directory to save weight matrices (created if missing)
        viz_dir (str): Directory to save visualizations (created if missing)
        video_path (str, optional): Path to the video file. Defaults to None.
        weights_dict (dict, optional): Dictionary with weights for each feature.
            Format: {
                'depth': {'foreground': float, 'background': float},
                'vis': {'foreground': float, 'background': float},
                'edge': {'foreground': float, 'background': float},
                'seg': {'foreground': float, 'background': float}
            }
            Values should be in range 0-1. Defaults to None, in which case the
            weights of ``setting_name`` are used.
        setting_name (str, optional): Weight setting name used when ``weights_dict``
            is None. Defaults to 'fg_vis_edge_bg_seg'.
        robot_keywords (list, optional): List of keywords to identify robot classes. Defaults to ["robot"].

    Returns:
        tuple: ``(output_dir, viz_dir)``

    Raises:
        ValueError: If no JSON files or no color entries are found, the first PNG
            cannot be read, JSON and PNG files share no frame numbers, or the
            video frames cannot be processed.
    """

    # Set default robot keywords if not provided
    if robot_keywords is None:
        robot_keywords = ["robot"]

    # Fall back to the named setting instead of failing later on ``None["depth"]``
    if weights_dict is None:
        weights_dict = WeightSettings.get_settings(setting_name)

    # Get all JSON files
    json_files = sorted(glob.glob(os.path.join(segmentation_dir, "*.json")))
    logger.info(f"Found {len(json_files)} JSON files")

    if not json_files:
        raise ValueError(f"No JSON files found in {segmentation_dir}")

    # For example directories, check for PNG files
    png_dir = os.path.join(os.path.dirname(segmentation_dir), "segmentation")
    png_files = []
    if os.path.exists(png_dir):
        png_files = sorted(glob.glob(os.path.join(png_dir, "*.png")))
        logger.info(f"Found {len(png_files)} PNG files in segmentation directory")

    # Step 1: Create a unified color-to-class mapping from all JSON files
    logger.info("Creating unified color-to-class mapping...")
    rgb_to_class = {}
    rgb_to_is_robot = {}

    for json_file in tqdm(json_files, desc="Processing JSON files for unified mapping"):
        with open(json_file, "r") as f:
            json_data = json.load(f)

        for color_key, data in json_data.items():
            color = parse_color_key(color_key)
            class_name = data["class"]

            # Store RGB color for matching
            rgb_to_class[color] = class_name
            rgb_to_is_robot[color] = any(keyword in class_name for keyword in robot_keywords)

    if not rgb_to_is_robot:
        # Without any color there is nothing to match pixels against.
        raise ValueError(f"No color entries found in JSON files under {segmentation_dir}")

    # Print statistics about the unified color mapping
    robot_colors = [color for color, is_robot in rgb_to_is_robot.items() if is_robot]
    logger.info(f"Unified mapping: Found {len(robot_colors)} robot colors out of {len(rgb_to_is_robot)} total colors")
    if robot_colors:
        logger.info(f"Robot classes: {[rgb_to_class[color] for color in robot_colors]}")

    # Convert color mapping to arrays for vectorized operations
    colors = list(rgb_to_is_robot.keys())
    color_array = np.array(colors)
    is_robot_array = np.array([rgb_to_is_robot[color] for color in colors], dtype=bool)

    if png_files:
        # Get dimensions from the first PNG file
        first_png = cv2.imread(png_files[0])
        if first_png is None:
            raise ValueError(f"Could not read PNG file: {png_files[0]}")

        height, width = first_png.shape[:2]

        # Match frame numbers between JSON and PNG files to ensure correct correspondence
        json_map = {_frame_number_from_path(f): f for f in json_files}
        png_map = {_frame_number_from_path(f): f for f in png_files}

        common_frames = sorted(set(json_map) & set(png_map))
        logger.info(f"Found {len(common_frames)} common frames between JSON and PNG files")

        if not common_frames:
            raise ValueError("No matching frames found between JSON and PNG files")

        # Keep only the files that have a counterpart, in frame order
        json_files = [json_map[frame] for frame in common_frames]
        png_files = [png_map[frame] for frame in common_frames]
        num_frames = len(json_files)

        logger.info(f"Using PNG dimensions: {width}x{height}, processing {num_frames} frames")
    else:
        # Get video information if no PNG files available
        try:
            width, height, frame_count, fps = get_video_info(video_path)
            logger.info(f"Video dimensions: {width}x{height}, {frame_count} frames, {fps} fps")
            num_frames = min(len(json_files), frame_count)
        except Exception as e:
            logger.warning(f"Warning: Could not get video information: {e}")
            # Use a default size if we can't get the video info
            width, height = 640, 480
            num_frames = len(json_files)
            logger.info(f"Using default dimensions: {width}x{height}, {num_frames} frames")

    # Make sure the destinations exist before any frame visualization or tensor is written
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(viz_dir, exist_ok=True)

    # Initialize weight tensors, one per feature
    weight_tensors = {
        feature: torch.zeros((num_frames, height, width)) for feature in ("depth", "vis", "edge", "seg")
    }

    # Process frames
    if png_files:
        # Process PNG files directly
        for i, (json_file, png_file) in enumerate(zip(json_files, png_files)):
            frame_num = _frame_number_from_path(json_file)

            # Read the corresponding PNG file
            frame = cv2.imread(png_file)
            if frame is None:
                logger.warning(f"Warning: Could not read frame {i} from PNG. Using blank frame.")
                frame = np.zeros((height, width, 3), dtype=np.uint8)

            _store_frame_weights(
                i, frame_num, frame, color_array, is_robot_array, weights_dict, weight_tensors, viz_dir, height, width
            )
    else:
        # Use video frames if available
        try:
            # Open the segmentation video
            cap = cv2.VideoCapture(video_path)
            try:
                if not cap.isOpened():
                    raise ValueError(f"Could not open video file: {video_path}")

                # Process each frame using the unified color mapping
                for i, json_file in enumerate(tqdm(json_files[:num_frames], desc="Processing frames")):
                    frame_num = _frame_number_from_path(json_file)

                    # Read the corresponding frame from the video
                    cap.set(cv2.CAP_PROP_POS_FRAMES, i)
                    ret, frame = cap.read()

                    if not ret:
                        logger.warning(f"Warning: Could not read frame {i} from video. Using blank frame.")
                        frame = np.zeros((height, width, 3), dtype=np.uint8)

                    _store_frame_weights(
                        i,
                        frame_num,
                        frame,
                        color_array,
                        is_robot_array,
                        weights_dict,
                        weight_tensors,
                        viz_dir,
                        height,
                        width,
                    )
            finally:
                # Close the video capture even when a frame fails to process
                cap.release()
        except Exception as e:
            logger.warning(f"Warning: Error processing video: {e}")
            logger.warning("Cannot process this example without proper frame data.")
            raise ValueError(f"Cannot process example without frame data: {e}") from e

    # Save weight tensors in half precision (float16) to reduce file size
    half_tensors = {feature: tensor.to(torch.float16) for feature, tensor in weight_tensors.items()}
    for feature, tensor in half_tensors.items():
        torch.save(tensor, os.path.join(output_dir, f"{feature}_weights.pt"))

    logger.info(f"Saved weight matrices to {output_dir}")
    logger.info(f"Weight matrix shape: {half_tensors['depth'].shape}, dtype: {half_tensors['depth'].dtype}")
    logger.info(f"Saved visualizations to {viz_dir}")

    return output_dir, viz_dir
423
+
424
+
425
def process_all_examples(input_dir, output_dir, setting_name="fg_vis_edge_bg_seg", robot_keywords=None):
    """Run the weight-matrix generation for every example folder inside ``input_dir``.

    An example that raises is logged and skipped, so one failure does not stop
    the remaining examples.

    Args:
        input_dir (str): Input directory containing example folders
        output_dir (str): Output directory for weight matrices
        setting_name (str, optional): Weight setting name. Defaults to 'fg_vis_edge_bg_seg'.
        robot_keywords (list, optional): List of keywords to identify robot classes. Defaults to None.

    Returns:
        list: ``(example, out_dir, viz_dir)`` tuples for the examples that succeeded;
            empty when ``input_dir`` is missing or holds no sub-directories.
    """
    if not os.path.exists(input_dir):
        logger.error(f"Input directory not found: {input_dir}")
        return []

    # Every sub-directory of input_dir is treated as one example, in sorted order.
    examples = sorted(entry for entry in os.listdir(input_dir) if os.path.isdir(os.path.join(input_dir, entry)))
    if not examples:
        logger.warning("No example directories found.")
        return []

    logger.info(f"Found {len(examples)} example directories:")
    for name in examples:
        logger.info(f" - {name}")

    results = []
    for name in examples:
        try:
            logger.info(f"\nProcessing {name}...")

            out_dir, viz_dir = process_example_with_dirs(name, input_dir, output_dir, setting_name, robot_keywords)
            results.append((name, out_dir, viz_dir))

            logger.info(f"Results for {name} saved to:")
            logger.info(f" Weight matrices: {out_dir}")
            logger.info(f" Visualizations: {viz_dir}")
        except Exception as e:
            # Best effort: report the failure and continue with the next example.
            logger.error(f"Error processing {name}: {e}")

    logger.info("\nAll examples processed.")
    return results
473
+
474
+
475
+ # Process a specific example with custom input and output directories
476
def process_example_with_dirs(
    example_name, input_dir, output_dir, setting_name="fg_vis_edge_bg_seg", robot_keywords=None
):
    """Generate the weight matrices for one example using explicit input/output roots.

    Args:
        example_name (str): Name of the example directory
        input_dir (str): Path to input directory containing example folders
        output_dir (str): Path to output directory for weight matrices
        setting_name (str, optional): Weight setting name. Defaults to 'fg_vis_edge_bg_seg'.
        robot_keywords (list, optional): List of keywords to identify robot classes. Defaults to None.

    Returns:
        The ``(output directory, visualization directory)`` pair for this example
        when the weights already exist; otherwise whatever
        ``process_segmentation_files`` returns.
    """
    source_root = os.path.join(input_dir, example_name)
    target_root = os.path.join(output_dir, example_name)
    viz_root = os.path.join(target_root, "visualizations")

    # The presence of depth_weights.pt is treated as "this example is already done".
    if os.path.exists(os.path.join(target_root, "depth_weights.pt")):
        logger.info(f"Weight files already exist for {example_name}, skipping processing")
        return target_root, viz_root

    # Make sure both output locations exist before any work starts.
    for directory in (target_root, viz_root):
        os.makedirs(directory, exist_ok=True)

    return process_segmentation_files(
        segmentation_dir=os.path.join(source_root, "segmentation_label"),
        output_dir=target_root,
        viz_dir=viz_root,
        video_path=os.path.join(source_root, "segmentation.mp4"),
        weights_dict=WeightSettings.get_settings(setting_name),
        setting_name=setting_name,
        robot_keywords=robot_keywords,
    )
520
+
521
+
522
if __name__ == "__main__":
    # Command-line entry point: read the options, configure logging, then
    # process every example directory under --input-dir.
    parser = argparse.ArgumentParser(
        description="Process segmentation files to generate spatial-temporal weight matrices"
    )
    add_option = parser.add_argument
    add_option(
        "--setting",
        type=str,
        default="fg_vis_edge_bg_seg",
        choices=WeightSettings.list_settings(),
        help="Weight setting to use (default: fg_vis_edge_bg_seg (setting1), fg_edge_bg_seg (setting2))",
    )
    add_option(
        "--input-dir",
        type=str,
        default="assets/robot_augmentation_example",
        help="Input directory containing example folders",
    )
    add_option(
        "--output-dir",
        type=str,
        default="outputs/robot_augmentation_example",
        help="Output directory for weight matrices",
    )
    add_option(
        "--robot-keywords",
        type=str,
        nargs="+",
        default=["world_robot", "gripper", "robot"],
        help="Keywords used to identify robot classes (default: world_robot gripper robot)",
    )
    add_option(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set the logging level",
    )
    args = parser.parse_args()

    # The level name is restricted by `choices`, so the attribute lookup is safe.
    logger.setLevel(getattr(logging, args.log_level))

    input_dir, output_dir = args.input_dir, args.output_dir
    setting_name, robot_keywords = args.setting, args.robot_keywords

    logger.info(f"Using input directory: {input_dir}")
    logger.info(f"Using output directory: {output_dir}")
    logger.info(f"Using weight setting: {setting_name}")
    logger.info(f"Using robot keywords: {robot_keywords}")

    # Process all examples with the provided input and output directories
    process_all_examples(input_dir, output_dir, setting_name, robot_keywords)
cosmos_transfer1/auxiliary/sam2/sam2_model.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import sys
18
+
19
+ import numpy as np
20
+ import pycocotools.mask as mask_util
21
+ import torch
22
+
23
+ from cosmos_transfer1.utils import log
24
+
25
+ sys.path.append("cosmos_transfer1/auxiliary")
26
+
27
+ import tempfile
28
+
29
+ from PIL import Image
30
+ from sam2.sam2_video_predictor import SAM2VideoPredictor
31
+ from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
32
+
33
+ from cosmos_transfer1.auxiliary.sam2.sam2_utils import (
34
+ capture_fps,
35
+ convert_masks_to_frames,
36
+ generate_tensor_from_images,
37
+ video_to_frames,
38
+ write_video,
39
+ )
40
+ from cosmos_transfer1.checkpoints import GROUNDING_DINO_MODEL_CHECKPOINT, SAM2_MODEL_CHECKPOINT
41
+
42
+
43
def rle_encode(mask: np.ndarray) -> dict:
    """
    Encode a boolean mask (of shape (T, H, W)) with pycocotools RLE, intended to
    match the format of eff_segmentation.RleMaskSAMv2 (from Yotta).

    Steps:
      1. Copy the mask into a Fortran-contiguous numpy array.
      2. Reshape it to a single column of shape (-1, 1), stored Fortran-contiguous.
      3. Run pycocotools.mask.encode on that column.
      4. Return the encoded data together with the original mask shape.

    NOTE(review): ndarray.reshape indexes elements in C order by default, so the
    column lists the mask values in C order even though both buffers are
    Fortran-contiguous — confirm this is what the reference format expects.
    """
    fortran_mask = np.array(mask, order="F")
    # pycocotools expects a Fortran-contiguous buffer, hence the second copy.
    column = np.array(fortran_mask.reshape(-1, 1), order="F")
    return {"data": mask_util.encode(column), "mask_shape": fortran_mask.shape}
58
+
59
+
60
class VideoSegmentationModel:
    """Video segmentation driven by points, a box, or a text prompt.

    Text prompts are turned into boxes with GroundingDINO on the first frame;
    the resulting (or user supplied) prompts are given to the SAM2 video
    predictor, which propagates the masks through the whole clip.
    """

    def __init__(self, **kwargs):
        """Initialize the model and load all required components.

        Args:
            **kwargs: Only ``grounding_model`` is read (GroundingDINO checkpoint or
                model id; defaults to ``GROUNDING_DINO_MODEL_CHECKPOINT``). Any
                other keys are ignored.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize SAM2 predictor
        self.sam2_predictor = SAM2VideoPredictor.from_pretrained(SAM2_MODEL_CHECKPOINT).to(self.device)

        # Initialize GroundingDINO for text-based detection
        self.grounding_model_name = kwargs.get("grounding_model", GROUNDING_DINO_MODEL_CHECKPOINT)
        self.processor = AutoProcessor.from_pretrained(self.grounding_model_name)
        self.grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(self.grounding_model_name).to(
            self.device
        )

    def get_boxes_from_text(self, image_path, text_prompt):
        """Get bounding boxes (and labels) from a text prompt using GroundingDINO.

        Detection is attempted with box/text thresholds of 0.15/0.25 and, when
        nothing is found, retried once with 0.1/0.1.

        Args:
            image_path: Path of the image to run detection on.
            text_prompt: Text describing the object(s) to detect.

        Returns:
            dict: ``{"boxes", "labels", "scores"}``. Boxes and scores are numpy
            arrays sorted by descending score; ``labels`` is ``None`` when the
            post-processor does not report labels.
        """
        image = Image.open(image_path).convert("RGB")

        inputs = self.processor(images=image, text=text_prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.grounding_model(**inputs)

        # Try with initial thresholds.
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=0.15,
            text_threshold=0.25,
            target_sizes=[image.size[::-1]],
        )

        boxes = results[0]["boxes"].cpu().numpy()
        scores = results[0]["scores"].cpu().numpy()
        labels = results[0].get("labels", None)
        if len(boxes) == 0:
            print(f"No boxes detected for prompt: '{text_prompt}'. Trying with lower thresholds...")
            results = self.processor.post_process_grounded_object_detection(
                outputs,
                inputs.input_ids,
                box_threshold=0.1,
                text_threshold=0.1,
                target_sizes=[image.size[::-1]],
            )
            boxes = results[0]["boxes"].cpu().numpy()
            scores = results[0]["scores"].cpu().numpy()
            labels = results[0].get("labels", None)

        if len(boxes) > 0:
            print(f"Found {len(boxes)} boxes with scores: {scores}")
            # Sort boxes by confidence score in descending order
            sorted_indices = np.argsort(scores)[::-1]
            boxes = boxes[sorted_indices]
            scores = scores[sorted_indices]
            if labels is not None:
                labels = np.array(labels)[sorted_indices]
        else:
            print("Still no boxes detected. Consider adjusting the prompt or using box/points mode.")

        return {"boxes": boxes, "labels": labels, "scores": scores}

    def visualize_frame(self, frame_idx, obj_ids, masks, video_dir, frame_names, visualization_data, save_dir=None):
        """
        Process a single frame: load the image, apply the segmentation mask to black out the
        detected object(s), and save both the masked frame and the binary mask image.

        Args:
            frame_idx: Index into ``frame_names`` of the frame to render.
            obj_ids: Object ids for the masks (not used by this method).
            masks: Either a tensor of mask logits (only ``masks[0]`` is used) or a
                dict mapping object id to a boolean mask (all masks are OR-ed).
            video_dir: Directory holding the extracted frames.
            frame_names: Sorted frame file names.
            visualization_data: Prompt metadata (not used by this method).
            save_dir: Directory for ``frame_{idx}_segmented.png`` and
                ``frame_{idx}_mask.png``; nothing is written when falsy.
        """
        # Load the frame.
        frame_path = os.path.join(video_dir, frame_names[frame_idx])
        img = Image.open(frame_path).convert("RGB")
        image_np = np.array(img)

        # Combine masks from the detection output.
        if isinstance(masks, torch.Tensor):
            mask_np = (masks[0] > 0.0).cpu().numpy().astype(bool)
            combined_mask = mask_np
        elif isinstance(masks, dict):
            first_mask = next(iter(masks.values()))
            combined_mask = np.zeros_like(first_mask, dtype=bool)
            for m in masks.values():
                combined_mask |= m
        else:
            combined_mask = None

        if combined_mask is not None:
            combined_mask = np.squeeze(combined_mask)

            # If the mask shape doesn't match the image, resize it.
            if combined_mask.shape != image_np.shape[:2]:
                mask_img = Image.fromarray((combined_mask.astype(np.uint8)) * 255)
                mask_img = mask_img.resize((image_np.shape[1], image_np.shape[0]), resample=Image.NEAREST)
                combined_mask = np.array(mask_img) > 127

            # Black out the detected region.
            image_np[combined_mask] = 0

            mask_image = (combined_mask.astype(np.uint8)) * 255
            mask_pil = Image.fromarray(mask_image)

        if save_dir:
            seg_frame_path = os.path.join(save_dir, f"frame_{frame_idx}_segmented.png")
            seg_pil = Image.fromarray(image_np)
            seg_pil.save(seg_frame_path)
            if combined_mask is not None:
                mask_save_path = os.path.join(save_dir, f"frame_{frame_idx}_mask.png")
                mask_pil.save(mask_save_path)

    def sample(self, **kwargs):
        """
        Main sampling function for video segmentation.
        Returns a list of detections in which each detection contains a phrase and
        an RLE-encoded segmentation mask (matching the output of the Grounded SAM model).

        Keyword Args:
            video_dir (str): Directory of frames named ``<number>.jpg`` / ``.jpeg``.
            mode (str): ``"points"``, ``"box"`` or ``"prompt"``. Defaults to ``"points"``.
            input_data (dict): ``points``/``labels``, ``box`` or ``text`` depending on mode.
            save_dir (str): Where per-frame visualizations are written.
            visualize (bool): Save masked frames and mask images. Defaults to False.
            legacy_mask (bool): In prompt mode, track only the highest scoring box.

        Returns:
            list: One ``{"phrase", "segmentation_mask_rle"}`` dict per tracked
            object, or ``[]`` when nothing was detected or propagated.
        """
        video_dir = kwargs.get("video_dir", "")
        mode = kwargs.get("mode", "points")
        input_data = kwargs.get("input_data", None)
        save_dir = kwargs.get("save_dir", None)
        visualize = kwargs.get("visualize", False)

        # Get frame names (expecting frames named as numbers with .jpg/.jpeg extension).
        frame_names = [p for p in os.listdir(video_dir) if os.path.splitext(p)[-1].lower() in [".jpg", ".jpeg"]]
        frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            state = self.sam2_predictor.init_state(video_path=video_dir)

            ann_frame_idx = 0
            ann_obj_id = 1
            boxes = None
            points = None
            labels = None
            box = None

            visualization_data = {"mode": mode, "points": None, "labels": None, "box": None, "boxes": None}

            if input_data is not None:
                if mode == "points":
                    points = input_data.get("points")
                    labels = input_data.get("labels")
                    frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
                        inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels
                    )
                    visualization_data["points"] = points
                    visualization_data["labels"] = labels
                elif mode == "box":
                    box = input_data.get("box")
                    frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
                        inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, box=box
                    )
                    visualization_data["box"] = box
                elif mode == "prompt":
                    text = input_data.get("text")
                    first_frame_path = os.path.join(video_dir, frame_names[0])
                    gd_results = self.get_boxes_from_text(first_frame_path, text)
                    boxes = gd_results["boxes"]
                    labels_out = gd_results["labels"]
                    scores = gd_results["scores"]
                    log.info(f"scores: {scores}")
                    if len(boxes) > 0:
                        legacy_mask = kwargs.get("legacy_mask", False)
                        if legacy_mask:
                            # Use only the highest confidence box for legacy mask
                            log.info(f"using legacy_mask: {legacy_mask}")
                            frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
                                inference_state=state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, box=boxes[0]
                            )
                            # Update boxes and labels after processing
                            boxes = boxes[:1]
                            if labels_out is not None:
                                labels_out = labels_out[:1]
                        else:
                            log.info(f"using new_mask: {legacy_mask}")
                            # Iterate over the boxes only: the labels are optional
                            # (None when the post-processor reports none), and
                            # zipping with None would raise.
                            for object_id, box in enumerate(boxes):
                                frame_idx, obj_ids, masks = self.sam2_predictor.add_new_points_or_box(
                                    inference_state=state, frame_idx=ann_frame_idx, obj_id=object_id, box=box
                                )
                        visualization_data["boxes"] = boxes
                        self.grounding_labels = [str(lbl) for lbl in labels_out] if labels_out is not None else [text]
                    else:
                        print("No boxes detected. Exiting.")
                        return []  # Return empty list if no detections

            # Only render the annotated frame when a prompt was actually added;
            # without input data there are no masks to draw yet.
            if visualize and input_data is not None:
                self.visualize_frame(
                    frame_idx=ann_frame_idx,
                    obj_ids=obj_ids,
                    masks=masks,
                    video_dir=video_dir,
                    frame_names=frame_names,
                    visualization_data=visualization_data,
                    save_dir=save_dir,
                )

            video_segments = {}  # keys: frame index, values: {obj_id: mask}
            for out_frame_idx, out_obj_ids, out_mask_logits in self.sam2_predictor.propagate_in_video(state):
                video_segments[out_frame_idx] = {
                    out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids)
                }

                # For propagated frames, visualization_data is not used.
                if visualize:
                    propagate_visualization_data = {
                        "mode": mode,
                        "points": None,
                        "labels": None,
                        "box": None,
                        "boxes": None,
                    }
                    self.visualize_frame(
                        frame_idx=out_frame_idx,
                        obj_ids=out_obj_ids,
                        masks=video_segments[out_frame_idx],
                        video_dir=video_dir,
                        frame_names=frame_names,
                        visualization_data=propagate_visualization_data,
                        save_dir=save_dir,
                    )

        # --- Post-process video_segments to produce a list of detections ---
        if len(video_segments) == 0:
            return []

        first_frame_path = os.path.join(video_dir, frame_names[0])
        first_frame = np.array(Image.open(first_frame_path).convert("RGB"))
        original_shape = first_frame.shape[:2]  # (height, width)

        object_masks = {}  # key: obj_id, value: list of 2D boolean masks
        sorted_frame_indices = sorted(video_segments.keys())
        for frame_idx in sorted_frame_indices:
            segments = video_segments[frame_idx]
            for obj_id, mask in segments.items():
                mask = np.squeeze(mask)
                if mask.ndim != 2:
                    print(f"Warning: Unexpected mask shape {mask.shape} for object {obj_id} in frame {frame_idx}.")
                    continue

                if mask.shape != original_shape:
                    mask_img = Image.fromarray(mask.astype(np.uint8) * 255)
                    mask_img = mask_img.resize((original_shape[1], original_shape[0]), resample=Image.NEAREST)
                    mask = np.array(mask_img) > 127

                if obj_id not in object_masks:
                    object_masks[obj_id] = []
                object_masks[obj_id].append(mask)

        detections = []
        for obj_id, mask_list in object_masks.items():
            mask_stack = np.stack(mask_list, axis=0)  # shape: (T, H, W)
            # Use our new rle_encode (which now follows the eff_segmentation.RleMaskSAMv2 format)
            rle = rle_encode(mask_stack)
            if mode == "prompt" and hasattr(self, "grounding_labels"):
                # NOTE(review): every object receives the first grounding label;
                # confirm whether a per-object label is intended.
                phrase = self.grounding_labels[0]
            else:
                phrase = (input_data or {}).get("text", "")
            detection = {"phrase": phrase, "segmentation_mask_rle": rle}
            detections.append(detection)

        return detections

    @staticmethod
    def parse_points(points_str):
        """Parse a string of points into a numpy array.
        Supports a single point ('200,300') or multiple points separated by ';' (e.g., '200,300;100,150').
        Entries that do not consist of exactly two comma-separated values are skipped.
        """
        points = []
        for point in points_str.split(";"):
            coords = point.split(",")
            if len(coords) != 2:
                continue
            points.append([float(coords[0]), float(coords[1])])
        return np.array(points, dtype=np.float32)

    @staticmethod
    def parse_labels(labels_str):
        """Parse a comma-separated string of labels into a numpy array."""
        return np.array([int(x) for x in labels_str.split(",")], dtype=np.int32)

    @staticmethod
    def parse_box(box_str):
        """Parse a comma-separated string of 4 box coordinates into a numpy array."""
        return np.array([float(x) for x in box_str.split(",")], dtype=np.float32)

    def __call__(
        self,
        input_video,
        output_video=None,
        output_tensor=None,
        prompt=None,
        box=None,
        points=None,
        labels=None,
        weight_scaler=None,
        binarize_video=False,
        legacy_mask=False,
    ):
        """Segment ``input_video`` and optionally write a mask video and/or mask tensor.

        Exactly one prompt type is used, chosen in the order points, box, prompt.

        Args:
            input_video: Path of the video to segment; must exist.
            output_video: Optional path for the rendered segmentation video.
            output_tensor: Optional path for the saved mask tensor.
            prompt: Text prompt (prompt mode).
            box: Comma-separated box string (box mode).
            points: Point string as accepted by ``parse_points`` (points mode).
            labels: Comma-separated point labels; required with ``points``.
            weight_scaler: Optional factor forwarded to ``generate_tensor_from_images``.
            binarize_video: Collapse the colored mask video to a 0/255 single channel.
            legacy_mask: In prompt mode, track only the highest scoring box.

        Raises:
            ValueError: If no prompt type is given, or ``points`` is given without ``labels``.
        """
        log.info(
            f"Processing video: {input_video} to generate segmentation video: {output_video} segmentation tensor: {output_tensor}"
        )
        assert os.path.exists(input_video)

        # Prepare input data based on the selected mode.
        if points is not None:
            if labels is None:
                raise ValueError("`labels` must be provided together with `points`.")
            mode = "points"
            input_data = {"points": self.parse_points(points), "labels": self.parse_labels(labels)}
        elif box is not None:
            mode = "box"
            input_data = {"box": self.parse_box(box)}
        elif prompt is not None:
            mode = "prompt"
            input_data = {"text": prompt}
        else:
            # Previously this fell through and failed later with an unbound local.
            raise ValueError("One of `points`, `box` or `prompt` must be provided.")

        with tempfile.TemporaryDirectory() as temp_input_dir:
            fps = capture_fps(input_video)
            video_to_frames(input_video, temp_input_dir)
            with tempfile.TemporaryDirectory() as temp_output_dir:
                masks = self.sample(
                    video_dir=temp_input_dir,
                    mode=mode,
                    input_data=input_data,
                    save_dir=str(temp_output_dir),
                    visualize=True,
                    legacy_mask=legacy_mask,
                )
                if not masks:
                    # Nothing was detected or propagated: there is no mask to render
                    # and no mask image to stack, so skip writing the outputs.
                    log.warning(f"No segmentation masks produced for {input_video}; no outputs written.")
                    return
                if output_video:
                    # dirname is "" for a bare file name and os.makedirs("") raises.
                    output_dir = os.path.dirname(output_video)
                    if output_dir:
                        os.makedirs(output_dir, exist_ok=True)
                    frames = convert_masks_to_frames(masks)
                    if binarize_video:
                        frames = np.any(frames > 0, axis=-1).astype(np.uint8) * 255
                    write_video(frames, output_video, fps)
                if output_tensor:
                    generate_tensor_from_images(
                        temp_output_dir, output_tensor, fps, "mask", weight_scaler=weight_scaler
                    )
cosmos_transfer1/auxiliary/sam2/sam2_pipeline.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import tempfile
18
+
19
+ import numpy as np
20
+
21
+ from cosmos_transfer1.auxiliary.sam2.sam2_model import VideoSegmentationModel
22
+ from cosmos_transfer1.auxiliary.sam2.sam2_utils import (
23
+ capture_fps,
24
+ generate_tensor_from_images,
25
+ generate_video_from_images,
26
+ video_to_frames,
27
+ )
28
+
29
+
30
def parse_args():
    """Build the SAM2 segmentation command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options; ``--input_video`` is mandatory,
        every other option has a default.
    """
    parser = argparse.ArgumentParser(description="Video Segmentation using SAM2")
    add = parser.add_argument

    # Input / output locations.
    add("--input_video", type=str, required=True, help="Path to input video file")
    add("--output_video", type=str, default="./outputs/output_video.mp4", help="Path to save the output video")
    add("--output_tensor", type=str, default="./outputs/output_tensor.pt", help="Path to save the output tensor")

    # Prompting mode and the matching prompt payloads.
    add("--mode", type=str, choices=["points", "box", "prompt"], default="points", help="Segmentation mode")
    add("--prompt", type=str, help="Text prompt for prompt mode")
    add(
        "--grounding_model_path",
        type=str,
        default="IDEA-Research/grounding-dino-tiny",
        help="Local directory for GroundingDINO model files",
    )
    add(
        "--points",
        type=str,
        default="200,300",
        help="Comma-separated point coordinates for points mode (e.g., '200,300' or for multiple points use ';' as a separator, e.g., '200,300;100,150').",
    )
    add(
        "--labels",
        type=str,
        default="1",
        help="Comma-separated labels for points mode (e.g., '1' or '1,0' for multiple points).",
    )
    add(
        "--box",
        type=str,
        default="300,0,500,400",
        help="Comma-separated box coordinates for box mode (e.g., '300,0,500,400').",
    )

    # Flag to control visualization.
    add("--visualize", action="store_true", help="If set, visualize segmentation frames (save images)")

    return parser.parse_args()
70
+
71
+
72
def parse_points(points_str):
    """Convert a point string into a float32 numpy array.

    A single point is written as 'x,y' ('200,300'); several points are joined
    with ';' ('200,300;100,150'). Entries that do not split into exactly two
    comma-separated values are ignored.
    """
    split_entries = (entry.split(",") for entry in points_str.split(";"))
    xy_pairs = [[float(pair[0]), float(pair[1])] for pair in split_entries if len(pair) == 2]
    return np.array(xy_pairs, dtype=np.float32)
83
+
84
+
85
def parse_labels(labels_str):
    """Convert a comma-separated label string (e.g. '1,0') into an int32 numpy array."""
    label_values = list(map(int, labels_str.split(",")))
    return np.array(label_values, dtype=np.int32)
88
+
89
+
90
def parse_box(box_str):
    """Convert a comma-separated coordinate string (e.g. '300,0,500,400') into a float32 numpy array."""
    coordinates = list(map(float, box_str.split(",")))
    return np.array(coordinates, dtype=np.float32)
93
+
94
+
95
def main():
    """Run SAM2 video segmentation from the command line.

    Extracts the frames of ``--input_video`` into a temporary directory, runs the
    segmentation model on them, then writes the colored mask video to
    ``--output_video`` and the stacked mask tensor to ``--output_tensor``.

    Raises:
        ValueError: If ``--mode prompt`` is selected without ``--prompt``.
    """
    import os  # local import: this module has no file-level `os` import

    args = parse_args()

    if args.mode == "prompt" and not args.prompt:
        raise ValueError("--prompt is required when --mode is 'prompt'.")

    # Initialize the segmentation model.
    # NOTE(review): VideoSegmentationModel reads the `grounding_model` keyword, so
    # `--grounding_model_path` is not picked up through **vars(args) — confirm
    # which key is intended before wiring it through.
    model = VideoSegmentationModel(**vars(args))

    # Prepare input data based on the selected mode.
    if args.mode == "points":
        input_data = {"points": parse_points(args.points), "labels": parse_labels(args.labels)}
    elif args.mode == "box":
        input_data = {"box": parse_box(args.box)}
    elif args.mode == "prompt":
        input_data = {"text": args.prompt}

    with tempfile.TemporaryDirectory() as temp_input_dir:
        fps = capture_fps(args.input_video)
        video_to_frames(args.input_video, temp_input_dir)
        with tempfile.TemporaryDirectory() as temp_output_dir:
            # visualize=True is required here: the mask images written to
            # temp_output_dir are what the tensor is built from.
            detections = model.sample(
                video_dir=temp_input_dir,
                mode=args.mode,
                input_data=input_data,
                save_dir=str(temp_output_dir),
                visualize=True,
            )
            if not detections:
                print("No segmentation masks were produced; nothing to write.")
                return

            # Create the output folders up front; dirname is "" for a bare file name.
            for out_path in (args.output_video, args.output_tensor):
                out_dir = os.path.dirname(out_path)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)

            # generate_video_from_images expects the detection list returned by
            # sample(), not the directory of saved frames.
            generate_video_from_images(detections, args.output_video, fps)
            generate_tensor_from_images(temp_output_dir, args.output_tensor, fps, "mask")
122
+
123
+
124
if __name__ == "__main__":
    # Script entry point: announce start-up on stdout, then run the CLI workflow.
    print("Starting video segmentation...")
    main()
cosmos_transfer1/auxiliary/sam2/sam2_utils.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import tempfile
18
+ import time
19
+
20
+ import cv2
21
+ import imageio
22
+ import numpy as np
23
+ import pycocotools.mask
24
+ import torch
25
+ from natsort import natsorted
26
+ from PIL import Image
27
+ from torchvision import transforms
28
+
29
+ from cosmos_transfer1.diffusion.datasets.augmentors.control_input import (
30
+ decode_partial_rle_width1,
31
+ segmentation_color_mask,
32
+ )
33
+ from cosmos_transfer1.utils import log
34
+
35
+
36
def write_video(frames, output_path, fps=30):
    """
    Write a sequence of [H, W, 3] or [H, W] frames to ``output_path`` at ``fps``.
    Single-channel frames are repeated across three channels before being appended.
    """
    with imageio.get_writer(output_path, fps=fps, macro_block_size=8) as video_writer:
        for image in frames:
            is_single_channel = len(image.shape) == 2
            if is_single_channel:
                image = image[:, :, None].repeat(3, axis=2)
            video_writer.append_data(image)
45
+
46
+
47
def capture_fps(input_video_path: str):
    """Return the frame rate OpenCV reports for ``input_video_path``.

    Args:
        input_video_path: Path of the video file to inspect.

    Returns:
        The value of ``cv2.CAP_PROP_FPS`` for the opened capture.
    """
    cap = cv2.VideoCapture(input_video_path)
    try:
        return cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Release the decoder handle; previously it was left open.
        cap.release()
51
+
52
+
53
def video_to_frames(input_loc, output_loc):
    """Function to extract frames from input video file
    and save them as separate frames in an output directory.

    Frames are written as zero-padded, 1-based ``NNNNN.jpg`` files. Extraction
    stops after the number of frames the container advertises, or as soon as a
    frame can no longer be read.

    Args:
        input_loc: Input video file.
        output_loc: Output directory to save the frames.
    Returns:
        None
    """
    os.makedirs(output_loc, exist_ok=True)
    # Log the time
    time_start = time.time()
    # Start capturing the feed
    cap = cv2.VideoCapture(input_loc)
    count = 0
    try:
        # Find the number of frames
        video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Number of frames: {video_length}")
        print("Converting video..\n")
        # Start converting the video
        while cap.isOpened():
            # Extract the frame
            ret, frame = cap.read()
            if not ret:
                # The advertised frame count can exceed what is decodable;
                # retrying a failed read here would loop forever.
                break
            # Write the results back to output location.
            cv2.imwrite(output_loc + "/%#05d.jpg" % (count + 1), frame)
            count = count + 1
            # If there are no more frames left
            if count > (video_length - 1):
                break
    finally:
        # Release the feed on every exit path.
        cap.release()
    # Log the time again
    time_end = time.time()
    # Print stats
    print("Done extracting frames.\n%d frames extracted" % count)
    print("It took %d seconds forconversion." % (time_end - time_start))
94
+
95
+
96
+ # Function to generate video
97
def convert_masks_to_frames(masks: list, num_masks_max: int = 100):
    """Decode RLE detections and render them as color frames.

    Args:
        masks: List of detection dicts; each holds ``"segmentation_mask_rle"``
            with ``"data"`` (the RLE payload) and ``"mask_shape"`` (unpacked
            here as ``(T, H, W)``). Must contain at least one entry, since the
            shape is read from ``masks[0]``.
        num_masks_max: Upper bound on how many detections (from the start of
            the list) are rendered.

    Returns:
        The output of ``segmentation_color_mask`` with its first axis moved last
        (presumably ``[T, H, W, 3]`` — TODO confirm against
        ``segmentation_color_mask``).
    """
    T, H, W = shape = masks[0]["segmentation_mask_rle"]["mask_shape"]
    # The whole clip is selected: frames [0, T).
    frame_start, frame_end = 0, T
    num_masks = min(num_masks_max, len(masks))
    mask_ids_select = np.arange(num_masks).tolist()

    all_masks = np.zeros((num_masks, T, H, W), dtype=np.uint8)
    for idx, mid in enumerate(mask_ids_select):
        mask = masks[mid]
        num_byte_per_mb = 1024 * 1024
        # total number of elements in uint8 (1 byte) / num_byte_per_mb
        # Masks larger than 256 MiB go through the partial decoder instead of
        # pycocotools' full decode.
        if shape[0] * shape[1] * shape[2] / num_byte_per_mb > 256:
            rle = decode_partial_rle_width1(
                mask["segmentation_mask_rle"]["data"],
                frame_start * shape[1] * shape[2],
                frame_end * shape[1] * shape[2],
            )
            partial_shape = (frame_end - frame_start, shape[1], shape[2])
            rle = rle.reshape(partial_shape) * 255
        else:
            rle = pycocotools.mask.decode(mask["segmentation_mask_rle"]["data"])
            rle = rle.reshape(shape) * 255
            # Select the frames that are in the video
            # (with frame_start=0 and frame_end=T this keeps every frame)
            frame_indices = np.arange(frame_start, frame_end).tolist()
            rle = np.stack([rle[i] for i in frame_indices])
        all_masks[idx] = rle
        del rle

    all_masks = segmentation_color_mask(all_masks)  # NTHW -> 3THW
    all_masks = all_masks.transpose(1, 2, 3, 0)
    return all_masks
128
+
129
+
130
def generate_video_from_images(masks: list, output_file_path: str, fps, num_masks_max: int = 100):
    """Render the given detections to frames and write them as a video.

    Args:
        masks: Detection list as accepted by ``convert_masks_to_frames``.
        output_file_path: Destination video path.
        fps: Frame rate passed to ``write_video``.
        num_masks_max: Maximum number of detections to render.
    """
    write_video(convert_masks_to_frames(masks, num_masks_max), output_file_path, fps)
    print("Video generated successfully!")
134
+
135
+
136
def generate_tensor_from_images(
    image_path_str: str, output_file_path: str, fps, search_pattern: str = None, weight_scaler: float = None
):
    """Stack the images in a directory into one tensor and save it with ``torch.save``.

    Args:
        image_path_str: Directory containing the images.
        output_file_path: Destination of the saved tensor.
        fps: Accepted for signature compatibility; not used here.
        search_pattern: When given, only file names containing this substring are used.
        weight_scaler: When given, the stacked tensor is multiplied by this factor.
    """
    image_path = os.path.abspath(image_path_str)
    # Natural sort keeps frame_2 ahead of frame_10.
    images = natsorted(os.listdir(image_path))
    if search_pattern is not None:
        images = [img for img in images if search_pattern in img]

    transform = transforms.ToTensor()
    image_tensors = []
    for image in images:
        # Close each file as soon as its pixels are converted, instead of
        # leaving the handles to the garbage collector.
        with Image.open(os.path.join(image_path, image)) as pil_image:
            img_tensor = transform(pil_image)
        image_tensors.append(img_tensor.squeeze(0))

    tensor = torch.stack(image_tensors)  # [T, H, W], binary values, float

    if weight_scaler is not None:
        log.info(f"scaling the tensor by the specified scale: {weight_scaler}")
        tensor = tensor * weight_scaler

    log.info(f"saving tensor shape: {tensor.shape} to {output_file_path}")
    torch.save(tensor, output_file_path)
162
+
163
+
164
if __name__ == "__main__":
    # Manual smoke test: extract the frames of the sample video into a fresh
    # temporary directory and report where they were written.
    input_loc = "cosmos_transfer1/models/sam2/assets/input_video.mp4"
    # mkdtemp creates a directory that persists for inspection; the previous
    # TemporaryDirectory object was discarded right after its name was read.
    output_loc = tempfile.mkdtemp()
    print(f"output_loc --- {output_loc}")
    video_to_frames(input_loc, output_loc)
cosmos_transfer1/auxiliary/tokenizer/inference/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
cosmos_transfer1/auxiliary/tokenizer/inference/image_cli.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """A CLI to run ImageTokenizer on plain images based on torch.jit.
17
+
18
+ Usage:
19
+ python3 -m cosmos_transfer1.auxiliary.tokenizer.inference.image_cli \
20
+ --image_pattern 'path/to/input/folder/*.jpg' \
21
+ --output_dir ./reconstructions \
22
+ --checkpoint_enc ./checkpoints/<model-name>/encoder.jit \
23
+ --checkpoint_dec ./checkpoints/<model-name>/decoder.jit
24
+
25
+ Optionally, you can run the model in pure PyTorch mode:
26
+ python3 -m cosmos_transfer1.auxiliary.tokenizer.inference.image_cli \
27
+ --image_pattern 'path/to/input/folder/*.jpg' \
28
+ --mode torch \
29
+ --tokenizer_type CI \
30
+ --spatial_compression 8 \
31
+ --checkpoint_enc ./checkpoints/<model-name>/encoder.jit \
32
+ --checkpoint_dec ./checkpoints/<model-name>/decoder.jit
33
+ """
34
+
35
+ import os
36
+ import sys
37
+ from argparse import ArgumentParser, Namespace
38
+ from typing import Any
39
+
40
+ import numpy as np
41
+ from loguru import logger as logging
42
+
43
+ from cosmos_transfer1.auxiliary.tokenizer.inference.image_lib import ImageTokenizer
44
+ from cosmos_transfer1.auxiliary.tokenizer.inference.utils import (
45
+ get_filepaths,
46
+ get_output_filepath,
47
+ read_image,
48
+ resize_image,
49
+ write_image,
50
+ )
51
+ from cosmos_transfer1.auxiliary.tokenizer.networks import TokenizerConfigs
52
+
53
+
54
def _parse_args() -> Namespace:
    """Parse the command-line options of the image tokenizer CLI.

    Returns:
        The populated ``argparse.Namespace`` (a single object, not a tuple).
    """
    parser = ArgumentParser(description="A CLI for running ImageTokenizer on plain images.")
    parser.add_argument(
        "--image_pattern",
        type=str,
        default="path/to/images/*.jpg",
        help="Glob pattern.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="JIT full Autoencoder model filepath.",
    )
    parser.add_argument(
        "--checkpoint_enc",
        type=str,
        default=None,
        help="JIT Encoder model filepath.",
    )
    parser.add_argument(
        "--checkpoint_dec",
        type=str,
        default=None,
        help="JIT Decoder model filepath.",
    )
    # No default: stays None unless passed, which the torch backend treats as an error.
    parser.add_argument(
        "--tokenizer_type",
        type=str,
        choices=["CI", "DI"],
        help="Specifies the tokenizer type.",
    )
    parser.add_argument(
        "--spatial_compression",
        type=int,
        choices=[8, 16],
        default=8,
        help="The spatial compression factor.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["torch", "jit"],
        default="jit",
        help="Specify the backend: native 'torch' or 'jit' (default: 'jit')",
    )
    parser.add_argument(
        "--short_size",
        type=int,
        default=None,
        help="The size to resample inputs. None, by default.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="Sets the precision. Default bfloat16.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device for invoking the model.",
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Output directory.")
    parser.add_argument(
        "--save_input",
        action="store_true",
        help="If on, the input image will be be outputed too.",
    )
    return parser.parse_args()
126
+
127
+
128
# NOTE: the arguments are parsed at import time; `args` is a module-level global read by `_run_eval`.
logging.info("Initializes args ...")
args = _parse_args()
# `--tokenizer_type` has no default, so it is None unless given on the command line; the torch
# backend needs it to look up the network configuration.
if args.mode == "torch" and args.tokenizer_type not in ["CI", "DI"]:
    logging.error("'torch' backend requires the tokenizer_type of 'CI' or 'DI'.")
    sys.exit(1)
133
+
134
+
135
def _run_eval() -> None:
    """Reconstruct every image matching ``args.image_pattern`` with the tokenizer.

    Reads the module-level ``args``. Each input is read, optionally resized, passed through the
    autoencoder and written to the output directory; with ``--save_input`` the (resized) input is
    written next to it with an ``_input`` suffix.
    """
    # Reconstruction needs either the full autoencoder or BOTH halves; a lone encoder or decoder
    # would only fail later, inside the model call.
    has_full_model = args.checkpoint is not None
    has_both_halves = args.checkpoint_enc is not None and args.checkpoint_dec is not None
    if not (has_full_model or has_both_halves):
        logging.warning("Aborting. Both encoder or decoder JIT required. Or provide the full autoencoder JIT model.")
        return

    if args.mode == "torch":
        # Build a fresh dict so the shared config stored on the enum member is not mutated.
        tokenizer_config = {
            **TokenizerConfigs[args.tokenizer_type].value,
            "spatial_compression": args.spatial_compression,
        }
    else:
        tokenizer_config = None

    logging.info(
        f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..."
    )
    autoencoder = ImageTokenizer(
        checkpoint=args.checkpoint,
        checkpoint_enc=args.checkpoint_enc,
        checkpoint_dec=args.checkpoint_dec,
        tokenizer_config=tokenizer_config,
        device=args.device,
        dtype=args.dtype,
    )

    filepaths = get_filepaths(args.image_pattern)
    logging.info(f"Found {len(filepaths)} images from {args.image_pattern}.")

    for filepath in filepaths:
        logging.info(f"Reading image {filepath} ...")
        image = read_image(filepath)
        image = resize_image(image, short_size=args.short_size)
        batch_image = np.expand_dims(image, axis=0)

        logging.info("Invoking the autoencoder model in ... ")
        output_image = autoencoder(batch_image)[0]

        output_filepath = get_output_filepath(filepath, output_dir=args.output_dir)
        logging.info(f"Outputing {output_filepath} ...")
        write_image(output_filepath, output_image)

        if args.save_input:
            # Split once at the real extension: str.replace would rewrite every occurrence of the
            # extension text in the path, and misbehaves when the extension is empty.
            root, ext = os.path.splitext(output_filepath)
            write_image(f"{root}_input{ext}", image)
180
+
181
+
182
@logging.catch(reraise=True)
def main() -> None:
    """CLI entry point; `logging.catch(reraise=True)` logs any exception and then re-raises it."""
    _run_eval()


if __name__ == "__main__":
    main()
cosmos_transfer1/auxiliary/tokenizer/inference/image_lib.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """A library for image tokenizers inference."""
17
+
18
+ from typing import Any
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from cosmos_transfer1.auxiliary.tokenizer.inference.utils import (
24
+ load_decoder_model,
25
+ load_encoder_model,
26
+ load_model,
27
+ numpy2tensor,
28
+ pad_image_batch,
29
+ tensor2numpy,
30
+ unpad_image_batch,
31
+ )
32
+
33
+
34
+ class ImageTokenizer(torch.nn.Module):
35
+ def __init__(
36
+ self,
37
+ checkpoint: str = None,
38
+ checkpoint_enc: str = None,
39
+ checkpoint_dec: str = None,
40
+ tokenizer_config: dict[str, Any] = None,
41
+ device: str = "cuda",
42
+ dtype: str = "bfloat16",
43
+ ) -> None:
44
+ super().__init__()
45
+ self._device = device
46
+ self._dtype = getattr(torch, dtype)
47
+ self._full_model = (
48
+ load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None
49
+ )
50
+ self._enc_model = (
51
+ load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype)
52
+ if checkpoint_enc is not None
53
+ else None
54
+ )
55
+ self._dec_model = (
56
+ load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype)
57
+ if checkpoint_dec is not None
58
+ else None
59
+ )
60
+
61
+ @torch.no_grad()
62
+ def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor:
63
+ """Reconstrcuts a batch of image tensors after embedding into a latent.
64
+
65
+ Args:
66
+ input_tensor: The input image Bx3xHxW layout, range [-1..1].
67
+ Returns:
68
+ The reconstructed tensor, layout Bx3xHxW, range [-1..1].
69
+ """
70
+ if self._full_model is not None:
71
+ output_tensor = self._full_model(input_tensor)
72
+ output_tensor = output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor
73
+ else:
74
+ output_latent = self.encode(input_tensor)[0]
75
+ output_tensor = self.decode(output_latent)
76
+ return output_tensor
77
+
78
+ @torch.no_grad()
79
+ def decode(self, input_latent: torch.Tensor) -> torch.Tensor:
80
+ """Decodes an image from a provided latent embedding.
81
+
82
+ Args:
83
+ input_latent: The continuous latent Bx16xhxw for CI,
84
+ or the discrete indices Bxhxw for DI.
85
+ Returns:
86
+ The output tensor in Bx3xHxW, range [-1..1].
87
+ """
88
+ return self._dec_model(input_latent)
89
+
90
+ @torch.no_grad()
91
+ def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]:
92
+ """Encodes an image into a latent embedding or code.
93
+
94
+ Args:
95
+ input_tensor: The input tensor Bx3xHxW layout, range [-1..1].
96
+ Returns:
97
+ For continuous image (CI) tokenizer, the tuple contains:
98
+ - The latent embedding, Bx16x(h)x(w), where the compression
99
+ rate is (H/h x W/w), and channel dimension of 16.
100
+ For discrete image (DI) tokenizer, the tuple contains:
101
+ - The indices, Bx(h)x(w), from a codebook of size 64K, which
102
+ corresponds to FSQ levels of (8,8,8,5,5,5).
103
+ - The discrete code, Bx6x(h)x(w), where the compression rate is
104
+ again (H/h x W/w), and channel dimension of 6.
105
+ """
106
+ output_latent = self._enc_model(input_tensor)
107
+ if isinstance(output_latent, torch.Tensor):
108
+ return output_latent
109
+ return output_latent[:-1]
110
+
111
+ @torch.no_grad()
112
+ def forward(self, image: np.ndarray) -> np.ndarray:
113
+ """Reconstructs an image using a pre-trained tokenizer.
114
+
115
+ Args:
116
+ image: The input image BxHxWxC layout, range [0..255].
117
+ Returns:
118
+ The reconstructed image in range [0..255], layout BxHxWxC.
119
+ """
120
+ padded_input_image, crop_region = pad_image_batch(image)
121
+ input_tensor = numpy2tensor(padded_input_image, dtype=self._dtype, device=self._device)
122
+ output_tensor = self.autoencode(input_tensor)
123
+ padded_output_image = tensor2numpy(output_tensor)
124
+ return unpad_image_batch(padded_output_image, crop_region)
cosmos_transfer1/auxiliary/tokenizer/inference/utils.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Utility functions for the inference libraries."""
17
+
18
+ import os
19
+ from glob import glob
20
+ from typing import Any
21
+
22
+ import mediapy as media
23
+ import numpy as np
24
+ import torch
25
+
26
+ from cosmos_transfer1.auxiliary.tokenizer.networks import TokenizerModels
27
+
28
+ _DTYPE, _DEVICE = torch.bfloat16, "cuda"
29
+ _UINT8_MAX_F = float(torch.iinfo(torch.uint8).max)
30
+ _SPATIAL_ALIGN = 16
31
+ _TEMPORAL_ALIGN = 8
32
+
33
+
34
def load_model(
    jit_filepath: str = None,
    tokenizer_config: dict[str, Any] = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads the full autoencoder from a JIT checkpoint file.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        tokenizer_config: When None the JIT module itself is returned. Otherwise a PyTorch model
            is built from this config and the JIT module's weights are copied into it.
        device: The device to load the model onto, default=cuda.
    Returns:
        The model on `device`, in eval mode.
    """
    if tokenizer_config is not None:
        model, jit_module = _load_pytorch_model(jit_filepath, tokenizer_config, device)
        # strict=False: keys that do not match between the two modules are ignored.
        model.load_state_dict(jit_module.state_dict(), strict=False)
        return model.eval().to(device)
    return load_jit_model(jit_filepath, device)
52
+
53
+
54
def load_encoder_model(
    jit_filepath: str = None,
    tokenizer_config: dict[str, Any] = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads the encoder from a JIT checkpoint file.

    Args:
        jit_filepath: The filepath to the JIT-compiled encoder.
        tokenizer_config: When None the JIT module itself is returned. Otherwise a PyTorch model
            is built from this config, its `encoder_jit()` sub-model is taken and the JIT
            module's weights are copied into it.
        device: The device to load the model onto, default=cuda.
    Returns:
        The encoder on `device`, in eval mode.
    """
    if tokenizer_config is not None:
        model, jit_module = _load_pytorch_model(jit_filepath, tokenizer_config, device)
        encoder = model.encoder_jit()
        # strict=False: keys that do not match between the two modules are ignored.
        encoder.load_state_dict(jit_module.state_dict(), strict=False)
        return encoder.eval().to(device)
    return load_jit_model(jit_filepath, device)
73
+
74
+
75
def load_decoder_model(
    jit_filepath: str = None,
    tokenizer_config: dict[str, Any] = None,
    device: str = "cuda",
) -> torch.nn.Module | torch.jit.ScriptModule:
    """Loads the decoder from a JIT checkpoint file.

    Args:
        jit_filepath: The filepath to the JIT-compiled decoder.
        tokenizer_config: When None the JIT module itself is returned. Otherwise a PyTorch model
            is built from this config, its `decoder_jit()` sub-model is taken and the JIT
            module's weights are copied into it.
        device: The device to load the model onto, default=cuda.
    Returns:
        The decoder on `device`, in eval mode.
    """
    if tokenizer_config is not None:
        model, jit_module = _load_pytorch_model(jit_filepath, tokenizer_config, device)
        decoder = model.decoder_jit()
        # strict=False: keys that do not match between the two modules are ignored.
        decoder.load_state_dict(jit_module.state_dict(), strict=False)
        return decoder.eval().to(device)
    return load_jit_model(jit_filepath, device)
94
+
95
+
96
def _load_pytorch_model(
    jit_filepath: str = None, tokenizer_config: dict[str, Any] = None, device: str = "cuda"
) -> tuple[torch.nn.Module, torch.jit.ScriptModule]:
    """Builds a PyTorch tokenizer and loads the matching JIT checkpoint.

    The weights are NOT copied here; callers transfer them with `load_state_dict`.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        tokenizer_config: Keyword arguments for the network constructor; its "name" entry
            selects the class from `TokenizerModels`.
        device: The device the JIT checkpoint is mapped onto, default=cuda.
    Returns:
        A pair `(model, ckpts)`: the freshly constructed PyTorch model and the loaded JIT module.
    """
    tokenizer_name = tokenizer_config["name"]
    # The whole config (including "name") is passed to the constructor.
    model = TokenizerModels[tokenizer_name].value(**tokenizer_config)
    ckpts = torch.jit.load(jit_filepath, map_location=device)
    return model, ckpts
111
+
112
+
113
def load_jit_model(jit_filepath: str = None, device: str = "cuda") -> torch.jit.ScriptModule:
    """Loads a torch.jit.ScriptModule from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        device: The device to load the model onto, default=cuda.
    Returns:
        The JIT module, switched to eval mode and moved to `device`.
    """
    return torch.jit.load(jit_filepath, map_location=device).eval().to(device)
124
+
125
+
126
def save_jit_model(
    model: torch.jit.ScriptModule | torch.jit.RecursiveScriptModule = None,
    jit_filepath: str = None,
) -> None:
    """Saves a torch.jit.ScriptModule or torch.jit.RecursiveScriptModule to file.

    Thin wrapper around `torch.jit.save`.

    Args:
        model: The JIT-compiled model to serialize.
        jit_filepath: The filepath the model is written to.
    """
    torch.jit.save(model, jit_filepath)
137
+
138
+
139
def get_filepaths(input_pattern) -> list[str]:
    """Returns the sorted, de-duplicated list of filepaths matching a glob pattern.

    Args:
        input_pattern: A glob pattern; converted with `str()` so path-like objects work.
    Returns:
        Matching paths in ascending order (an empty list when nothing matches).
    """
    # De-duplicate first, then sort: converting a sorted list to a set discards the order.
    return sorted(set(glob(str(input_pattern))))
143
+
144
+
145
def get_output_filepath(filepath: str, output_dir: str = None) -> str:
    """Returns the output filepath for the given input filepath.

    The output directory is created if it does not exist yet.

    Args:
        filepath: The input filepath; only its basename is reused for the output.
        output_dir: Target directory. Defaults to a `reconstructions` folder next to the input.
    Returns:
        The path of the output file inside the output directory.
    """
    # os.path.join keeps the default relative when `filepath` has no directory part; plain
    # string formatting would turn it into "/reconstructions" at the filesystem root.
    output_dir = output_dir or os.path.join(os.path.dirname(filepath), "reconstructions")
    os.makedirs(output_dir, exist_ok=True)
    return os.path.join(output_dir, os.path.basename(filepath))
151
+
152
+
153
def read_image(filepath: str) -> np.ndarray:
    """Reads an image from a filepath and normalizes it to three channels.

    Args:
        filepath: The filepath to the image.

    Returns:
        The image as a numpy array, layout HxWxC, range [0..255], uint8 dtype.
    """
    image = media.read_image(filepath)
    if image.ndim == 2:
        # Grey scale: replicate the single plane, since the tokenizers assume 3-channel RGB.
        image = np.stack((image, image, image), axis=-1)
    elif image.shape[-1] == 4:
        # RGBA: drop the alpha channel.
        image = image[..., :3]
    return image
171
+
172
+
173
def read_video(filepath: str) -> np.ndarray:
    """Reads a video from a filepath and normalizes it to three channels.

    Args:
        filepath: The filepath to the video.
    Returns:
        The video as a numpy array, layout TxHxWxC, range [0..255], uint8 dtype.
    """
    video = media.read_video(filepath)
    if video.ndim == 3:
        # Grey scale frames: replicate the plane, since the tokenizers assume 3-channel video.
        video = np.stack((video, video, video), axis=-1)
    elif video.shape[-1] == 4:
        # RGBA: drop the alpha channel.
        video = video[..., :3]
    return video
190
+
191
+
192
def resize_image(image: np.ndarray, short_size: int = None) -> np.ndarray:
    """Resizes an image so that its short side equals `short_size`.

    The long side is scaled by the same ratio (rounded to nearest) and bumped to the next even
    number when odd. With `short_size=None` the input is returned untouched.

    Args:
        image: The image to resize, layout HxWxC, of any range.
        short_size: The size of the short side.
    Returns:
        The resized image.
    """
    if short_size is None:
        return image
    height, width = image.shape[-3:-1]
    if height > width:
        long_side = int(height * short_size / width + 0.5)
        long_side += long_side % 2  # round an odd length up to even
        target_shape = (long_side, short_size)
    else:
        long_side = int(width * short_size / height + 0.5)
        long_side += long_side % 2  # round an odd length up to even
        target_shape = (short_size, long_side)
    return media.resize_image(image, shape=target_shape)
214
+
215
+
216
def resize_video(video: np.ndarray, short_size: int = None) -> np.ndarray:
    """Resizes a video so that its short side equals `short_size`.

    The long side is scaled by the same ratio (rounded to nearest) and bumped to the next even
    number when odd. With `short_size=None` the input is returned untouched.

    Args:
        video: The video to resize, layout TxHxWxC, of any range.
        short_size: The size of the short side.
    Returns:
        The resized video.
    """
    if short_size is None:
        return video
    height, width = video.shape[-3:-1]
    if height > width:
        long_side = int(height * short_size / width + 0.5)
        long_side += long_side % 2  # round an odd length up to even
        target_shape = (long_side, short_size)
    else:
        long_side = int(width * short_size / height + 0.5)
        long_side += long_side % 2  # round an odd length up to even
        target_shape = (short_size, long_side)
    return media.resize_video(video, shape=target_shape)
238
+
239
+
240
def write_image(filepath: str, image: np.ndarray):
    """Writes an image to a filepath.

    Thin wrapper that forwards to `media.write_image` and returns its result.

    Args:
        filepath: Destination path of the image file.
        image: The image array to write.
    """
    return media.write_image(filepath, image)
243
+
244
+
245
def write_video(filepath: str, video: np.ndarray, fps: int = 24) -> None:
    """Writes a video to a filepath.

    Thin wrapper that forwards to `media.write_video` and returns its result.

    Args:
        filepath: Destination path of the video file.
        video: The video array to write.
        fps: Frames per second passed to the writer, default 24.
    """
    return media.write_video(filepath, video, fps=fps)
248
+
249
+
250
def numpy2tensor(
    input_image: np.ndarray,
    dtype: torch.dtype = _DTYPE,
    device: str = _DEVICE,
    range_min: int = -1,
) -> torch.Tensor:
    """Converts a channels-last [0..255] array into a channels-first tensor.

    Args:
        input_image: A batch of images in range [0..255], BxHxWx3 layout.
        dtype: The dtype of the returned tensor.
        device: The device of the returned tensor.
        range_min: With -1 the values are mapped to [-1..1]; any other value leaves them in
            [0..1].
    Returns:
        A torch.Tensor of layout Bx3xHxW in range [-1..1], dtype.
    """
    trailing_axes = list(range(1, input_image.ndim))
    # Move the last (channel) axis right behind the batch axis.
    permutation = (0, *trailing_axes[-1:], *trailing_axes[:-1])
    scaled = input_image.transpose(permutation) / _UINT8_MAX_F
    if range_min == -1:
        scaled = 2.0 * scaled - 1.0
    return torch.from_numpy(scaled).to(dtype).to(device)
269
+
270
+
271
def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray:
    """Converts a channels-first tensor to a channels-last uint8 image in range [0..255].

    Args:
        input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1]
            (or [0..1] when `range_min` is not -1).
        range_min: With -1 the input is first mapped from [-1..1] to [0..1].
    Returns:
        A numpy image of layout BxHxWx3, range [0..255], uint8 dtype.
    """
    # Always go through float32: numpy cannot represent bfloat16, so `.numpy()` would fail for
    # such tensors whenever the range_min == -1 branch (which used to do the cast) is skipped.
    input_tensor = input_tensor.float()
    if range_min == -1:
        input_tensor = (input_tensor + 1.0) / 2.0
    ndim = input_tensor.ndim
    output_image = input_tensor.clamp(0, 1).cpu().numpy()
    # Move the channel axis (1) to the end.
    output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,))
    # +0.5 before the truncating cast rounds to nearest.
    return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8)
285
+
286
+
287
def pad_image_batch(batch: np.ndarray, spatial_align: int = _SPATIAL_ALIGN) -> tuple[np.ndarray, list[int]]:
    """Zero-pads a batch of images so height and width are multiples of `spatial_align`.

    The padding is split between both sides; the odd remainder goes to the bottom/right.

    Args:
        batch: The batch of images to pad, layout BxHxWx3, in any range.
        spatial_align: The alignment to pad to.
    Returns:
        The padded batch and the crop region `[y1, x1, y2, x2]` that recovers the original.
    """
    height, width = batch.shape[1:3]
    pad_h = -height % spatial_align  # amount missing to the next multiple (0 if aligned)
    pad_w = -width % spatial_align
    top, left = pad_h >> 1, pad_w >> 1

    crop_region = [top, left, height + top, width + left]
    pad_widths = (
        (0, 0),
        (top, pad_h - top),
        (left, pad_w - left),
        (0, 0),
    )
    return np.pad(batch, pad_widths, mode="constant"), crop_region
318
+
319
+
320
def pad_video_batch(
    batch: np.ndarray,
    temporal_align: int = _TEMPORAL_ALIGN,
    spatial_align: int = _SPATIAL_ALIGN,
) -> tuple[np.ndarray, list[int]]:
    """Pads a batch of videos to the temporal and spatial alignment.

    Height and width are zero-padded to multiples of `spatial_align`. The frame axis is padded
    by repeating the edge frames (`mode="edge"`) until `num_frames - 1` is a multiple of
    `temporal_align`. Padding is split between both sides, the odd remainder going to the end.

    Args:
        batch: The batch of videos to pad., layout BxFxHxWx3, in any range.
        temporal_align: The alignment of the frame axis.
        spatial_align: The alignment of height and width.
    Returns:
        The padded batch and the crop region `[f1, y1, x1, f2, y2, x2]`.
    """
    num_frames, height, width = batch.shape[-4:-1]
    pad_h = -height % spatial_align
    pad_w = -width % spatial_align
    pad_t = -(num_frames - 1) % temporal_align
    front, top, left = pad_t >> 1, pad_h >> 1, pad_w >> 1

    crop_region = [
        front,
        top,
        left,
        num_frames + front,
        height + top,
        width + left,
    ]
    # Spatial axes first, filled with zeros.
    spatially_padded = np.pad(
        batch,
        ((0, 0), (0, 0), (top, pad_h - top), (left, pad_w - left), (0, 0)),
        mode="constant",
    )
    # Then the frame axis, replicating the first/last frame.
    padded = np.pad(
        spatially_padded,
        ((0, 0), (front, pad_t - front), (0, 0), (0, 0), (0, 0)),
        mode="edge",
    )
    return padded, crop_region
373
+
374
+
375
def unpad_video_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray:
    """Crops a padded video batch back to `crop_region`.

    Args:
        batch: A batch of numpy videos, layout BxFxHxWxC.
        crop_region: [f1,y1,x1,f2,y2,x2] first, top, left, last, bot, right crop indices.

    Returns:
        np.ndarray: Cropped numpy video, layout BxFxHxWxC.
    """
    assert len(crop_region) == 6, "crop_region should be len of 6."
    first, top, left, last, bottom, right = crop_region
    frames, rows, cols = slice(first, last), slice(top, bottom), slice(left, right)
    return batch[..., frames, rows, cols, :]
388
+
389
+
390
def unpad_image_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray:
    """Crops a padded image batch back to `crop_region`.

    Args:
        batch: A batch of numpy images, layout BxHxWxC.
        crop_region: [y1,x1,y2,x2] top, left, bot, right crop indices.

    Returns:
        np.ndarray: Cropped numpy image, layout BxHxWxC.
    """
    assert len(crop_region) == 4, "crop_region should be len of 4."
    top, left, bottom, right = crop_region
    rows, cols = slice(top, bottom), slice(left, right)
    return batch[..., rows, cols, :]
cosmos_transfer1/auxiliary/tokenizer/inference/video_cli.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """A CLI to run CausalVideoTokenizer on plain videos based on torch.jit.
17
+
18
+ Usage:
19
+ python3 -m cosmos_transfer1.auxiliary.tokenizer.inference.video_cli \
20
+ --video_pattern 'path/to/video/samples/*.mp4' \
21
+ --output_dir ./reconstructions \
22
+ --checkpoint_enc ./checkpoints/<model-name>/encoder.jit \
23
+ --checkpoint_dec ./checkpoints/<model-name>/decoder.jit
24
+
25
+ Optionally, you can run the model in pure PyTorch mode:
26
+ python3 -m cosmos_transfer1.auxiliary.tokenizer.inference.video_cli \
27
+ --video_pattern 'path/to/video/samples/*.mp4' \
28
+ --mode=torch \
29
+ --tokenizer_type=CV \
30
+ --temporal_compression=4 \
31
+ --spatial_compression=8 \
32
+ --checkpoint_enc ./checkpoints/<model-name>/encoder.jit \
33
+ --checkpoint_dec ./checkpoints/<model-name>/decoder.jit
34
+ """
35
+
36
+ import os
37
+ import sys
38
+ from argparse import ArgumentParser, Namespace
39
+ from typing import Any
40
+
41
+ import numpy as np
42
+ from loguru import logger as logging
43
+
44
+ from cosmos_transfer1.auxiliary.tokenizer.inference.utils import (
45
+ get_filepaths,
46
+ get_output_filepath,
47
+ read_video,
48
+ resize_video,
49
+ write_video,
50
+ )
51
+ from cosmos_transfer1.auxiliary.tokenizer.inference.video_lib import CausalVideoTokenizer
52
+ from cosmos_transfer1.auxiliary.tokenizer.networks import TokenizerConfigs
53
+
54
+
55
def _parse_args() -> Namespace:
    """Parse the command-line options of the video tokenizer CLI.

    Returns:
        The populated ``argparse.Namespace`` (a single object, not a tuple).
    """
    parser = ArgumentParser(description="A CLI for CausalVideoTokenizer.")
    parser.add_argument(
        "--video_pattern",
        type=str,
        default="path/to/videos/*.mp4",
        help="Glob pattern.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="JIT full Autoencoder model filepath.",
    )
    parser.add_argument(
        "--checkpoint_enc",
        type=str,
        default=None,
        help="JIT Encoder model filepath.",
    )
    parser.add_argument(
        "--checkpoint_dec",
        type=str,
        default=None,
        help="JIT Decoder model filepath.",
    )
    # No default: stays None unless passed, which the torch backend treats as an error.
    parser.add_argument(
        "--tokenizer_type",
        type=str,
        choices=["CV", "DV"],
        help="Specifies the tokenizer type.",
    )
    parser.add_argument(
        "--spatial_compression",
        type=int,
        choices=[8, 16],
        default=8,
        help="The spatial compression factor.",
    )
    parser.add_argument(
        "--temporal_compression",
        type=int,
        choices=[4, 8],
        default=4,
        help="The temporal compression factor.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["torch", "jit"],
        default="jit",
        help="Specify the backend: native 'torch' or 'jit' (default: 'jit')",
    )
    parser.add_argument(
        "--short_size",
        type=int,
        default=None,
        help="The size to resample inputs. None, by default.",
    )
    parser.add_argument(
        "--temporal_window",
        type=int,
        default=17,
        help="The temporal window to operate at a time.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="Sets the precision, default bfloat16.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device for invoking the model.",
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Output directory.")
    parser.add_argument(
        "--output_fps",
        type=float,
        default=24.0,
        help="Output frames-per-second (FPS).",
    )
    parser.add_argument(
        "--save_input",
        action="store_true",
        help="If on, the input video will be be outputted too.",
    )

    return parser.parse_args()
147
+
148
+
149
# NOTE: the arguments are parsed at import time; `args` is a module-level global read by `_run_eval`.
logging.info("Initializes args ...")
args = _parse_args()
# `--tokenizer_type` has no default, so it is None unless given on the command line; the torch
# backend needs it to look up the network configuration.
if args.mode == "torch" and args.tokenizer_type not in ["CV", "DV"]:
    logging.error("'torch' backend requires the tokenizer_type of 'CV' or 'DV'.")
    sys.exit(1)
154
+
155
+
156
def _run_eval() -> None:
    """Reconstruct every video matching ``args.video_pattern`` with the tokenizer.

    Reads the module-level ``args``. Each input is read, optionally resized, passed through the
    autoencoder and written to the output directory; with ``--save_input`` the (resized) input is
    written next to it with an ``_input`` suffix.
    """
    # Reconstruction needs either the full autoencoder or BOTH halves; a lone encoder or decoder
    # would only fail later, inside the model call.
    has_full_model = args.checkpoint is not None
    has_both_halves = args.checkpoint_enc is not None and args.checkpoint_dec is not None
    if not (has_full_model or has_both_halves):
        logging.warning("Aborting. Both encoder or decoder JIT required. Or provide the full autoencoder JIT model.")
        return

    if args.mode == "torch":
        # Build a fresh dict so the shared config stored on the enum member is not mutated.
        tokenizer_config = {
            **TokenizerConfigs[args.tokenizer_type].value,
            "spatial_compression": args.spatial_compression,
            "temporal_compression": args.temporal_compression,
        }
    else:
        tokenizer_config = None

    logging.info(
        f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..."
    )
    autoencoder = CausalVideoTokenizer(
        checkpoint=args.checkpoint,
        checkpoint_enc=args.checkpoint_enc,
        checkpoint_dec=args.checkpoint_dec,
        tokenizer_config=tokenizer_config,
        device=args.device,
        dtype=args.dtype,
    )

    logging.info(f"Looking for files matching video_pattern={args.video_pattern} ...")
    filepaths = get_filepaths(args.video_pattern)
    logging.info(f"Found {len(filepaths)} videos from {args.video_pattern}.")

    for filepath in filepaths:
        logging.info(f"Reading video {filepath} ...")
        video = read_video(filepath)
        video = resize_video(video, short_size=args.short_size)

        logging.info("Invoking the autoencoder model in ... ")
        batch_video = video[np.newaxis, ...]
        output_video = autoencoder(batch_video, temporal_window=args.temporal_window)[0]
        logging.info("Constructing output filepath ...")
        output_filepath = get_output_filepath(filepath, output_dir=args.output_dir)
        logging.info(f"Outputing {output_filepath} ...")
        write_video(output_filepath, output_video, fps=args.output_fps)
        if args.save_input:
            # Split once at the real extension: str.replace would rewrite every occurrence of the
            # extension text in the path, and misbehaves when the extension is empty.
            root, ext = os.path.splitext(output_filepath)
            write_video(f"{root}_input{ext}", video, fps=args.output_fps)
202
+
203
+
204
@logging.catch(reraise=True)
def main() -> None:
    """CLI entry point; `logging.catch(reraise=True)` logs any exception and then re-raises it."""
    _run_eval()


if __name__ == "__main__":
    main()
cosmos_transfer1/auxiliary/tokenizer/inference/video_lib.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """A library for Causal Video Tokenizer inference."""
17
+
18
+ from typing import Any
19
+
20
+ import numpy as np
21
+ import torch
22
+ from tqdm import tqdm
23
+
24
+ from cosmos_transfer1.auxiliary.tokenizer.inference.utils import (
25
+ load_decoder_model,
26
+ load_encoder_model,
27
+ load_model,
28
+ numpy2tensor,
29
+ pad_video_batch,
30
+ tensor2numpy,
31
+ unpad_video_batch,
32
+ )
33
+
34
+
35
class CausalVideoTokenizer(torch.nn.Module):
    """Inference wrapper around a causal video tokenizer.

    Up to three models are held, each loaded only when its checkpoint path is
    given: a full autoencoder, a standalone encoder and a standalone decoder.
    ``autoencode`` prefers the full model and otherwise chains
    ``encode`` -> ``decode``.
    """

    def __init__(
        self,
        checkpoint: str = None,
        checkpoint_enc: str = None,
        checkpoint_dec: str = None,
        tokenizer_config: dict[str, Any] = None,
        device: str = "cuda",
        dtype: str = "bfloat16",
    ) -> None:
        """Load every model whose checkpoint is provided and cast it to ``dtype``.

        Args:
            checkpoint: Path of the full autoencoder checkpoint, or None to skip it.
            checkpoint_enc: Path of the encoder checkpoint, or None to skip it.
            checkpoint_dec: Path of the decoder checkpoint, or None to skip it.
            tokenizer_config: Forwarded unchanged to the model loaders.
            device: Device forwarded to the loaders and used for input tensors.
            dtype: Name of a ``torch`` dtype attribute, e.g. ``"bfloat16"``.
        """
        super().__init__()
        self._device = device
        self._dtype = getattr(torch, dtype)
        self._full_model = (
            load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None
        )
        self._enc_model = (
            load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype)
            if checkpoint_enc is not None
            else None
        )
        self._dec_model = (
            load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype)
            if checkpoint_dec is not None
            else None
        )

    @torch.no_grad()
    def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Reconstructs a batch of video tensors after embedding into a latent.

        Args:
            input_tensor: The input video Bx3xTxHxW layout, range [-1..1].
        Returns:
            The reconstructed video, layout Bx3xTxHxW, range [-1..1].
        """
        if self._full_model is not None:
            output_tensor = self._full_model(input_tensor)
            return output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor

        encoded = self.encode(input_tensor)
        # `encode` returns a bare tensor when the encoder emits one and a tuple
        # otherwise. Indexing a bare tensor with [0] would drop the batch
        # dimension, so only unpack when a tuple actually came back.
        output_latent = encoded if isinstance(encoded, torch.Tensor) else encoded[0]
        return self.decode(output_latent)

    @torch.no_grad()
    def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]:
        """Encodes a video tensor into a CausalVideo latent or code.

        Args:
            input_tensor: The input tensor Bx3xTxHxW layout, range [-1..1].
        Returns:
            If the encoder emits a single tensor, that tensor is returned as-is.
            Otherwise the encoder output without its last element is returned:
            For causal continuous video (CV) tokenizer, the tuple contains:
                - The latent embedding, Bx16x(t)x(h)x(w), where the compression
                rate is (T/t x H/h x W/w), and channel dimension of 16.
            For causal discrete video (DV) tokenizer, the tuple contains:
              1) The indices, Bx(t)x(h)x(w), from a codebook of size 64K, which
                is formed by FSQ levels of (8,8,8,5,5,5).
              2) The discrete code, Bx6x(t)x(h)x(w), where the compression rate
                is again (T/t x H/h x W/w), and channel dimension of 6.
        """
        assert self._enc_model is not None, "no encoder loaded; pass `checkpoint_enc` at construction."
        assert input_tensor.ndim == 5, "input video should be of 5D."

        output_latent = self._enc_model(input_tensor)
        if isinstance(output_latent, torch.Tensor):
            return output_latent
        return output_latent[:-1]

    @torch.no_grad()
    def decode(self, input_latent: torch.Tensor) -> torch.Tensor:
        """Decodes a CausalVideo latent (or indices) back into a video tensor.

        Args:
            input_latent: The continuous latent Bx16xtxhxw for CV,
                        or the discrete indices Bxtxhxw for DV.
        Returns:
            The reconstructed tensor, layout [B,3,1+(T-1)*8,H*16,W*16] in range [-1..1].
        """
        assert self._dec_model is not None, "no decoder loaded; pass `checkpoint_dec` at construction."
        assert input_latent.ndim >= 4, "input latent should be of 5D for continuous and 4D for discrete."
        return self._dec_model(input_latent)

    def forward(
        self,
        video: np.ndarray,
        temporal_window: int = 17,
    ) -> np.ndarray:
        """Reconstructs video using a pre-trained CausalTokenizer autoencoder.
        Given a video of arbitrary length, the forward invokes the CausalVideoTokenizer
        in a sliding manner with a `temporal_window` size.

        Args:
            video: The input video BxTxHxWx3 layout, range [0..255].
            temporal_window: The length of the temporal window to process, default=17.
        Returns:
            The reconstructed video in range [0..255], layout BxTxHxWx3.
        """
        assert video.ndim == 5, "input video should be of 5D."
        num_frames = video.shape[1]  # can be of any length.
        # Ceiling division: the final window may hold fewer than `temporal_window` frames.
        num_windows = (num_frames - 1) // temporal_window + 1
        output_video_list = []
        for idx in tqdm(range(0, num_windows)):
            # Input video for the current window.
            start, end = idx * temporal_window, (idx + 1) * temporal_window
            input_video = video[:, start:end, ...]

            # Spatio-temporally pad input_video so it's evenly divisible.
            padded_input_video, crop_region = pad_video_batch(input_video)
            input_tensor = numpy2tensor(padded_input_video, dtype=self._dtype, device=self._device)
            output_tensor = self.autoencode(input_tensor)
            padded_output_video = tensor2numpy(output_tensor)
            output_video = unpad_video_batch(padded_output_video, crop_region)

            output_video_list.append(output_video)
        return np.concatenate(output_video_list, axis=1)
cosmos_transfer1/auxiliary/tokenizer/modules/__init__.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from enum import Enum
17
+
18
+ from cosmos_transfer1.auxiliary.tokenizer.modules.distributions import GaussianDistribution, IdentityDistribution
19
+ from cosmos_transfer1.auxiliary.tokenizer.modules.layers2d import Decoder, Encoder
20
+ from cosmos_transfer1.auxiliary.tokenizer.modules.layers3d import (
21
+ DecoderBase,
22
+ DecoderFactorized,
23
+ EncoderBase,
24
+ EncoderFactorized,
25
+ )
26
+ from cosmos_transfer1.auxiliary.tokenizer.modules.quantizers import (
27
+ FSQuantizer,
28
+ LFQuantizer,
29
+ ResidualFSQuantizer,
30
+ VectorQuantizer,
31
+ )
32
+
33
+
34
class EncoderType(Enum):
    """Selectable 2D encoder implementations; each member's value is the class to instantiate."""

    Default = Encoder
36
+
37
+
38
class DecoderType(Enum):
    """Selectable 2D decoder implementations; each member's value is the class to instantiate."""

    Default = Decoder
40
+
41
+
42
class Encoder3DType(Enum):
    """Selectable 3D encoder implementations; each member's value is the class to instantiate."""

    BASE = EncoderBase
    FACTORIZED = EncoderFactorized
45
+
46
+
47
class Decoder3DType(Enum):
    """Selectable 3D decoder implementations; each member's value is the class to instantiate."""

    BASE = DecoderBase
    FACTORIZED = DecoderFactorized
50
+
51
+
52
class ContinuousFormulation(Enum):
    """Latent formulations for continuous tokenizers, mapped to their distribution modules."""

    VAE = GaussianDistribution  # sampled Gaussian latent
    AE = IdentityDistribution  # parameters passed through unchanged
55
+
56
+
57
class DiscreteQuantizer(Enum):
    """Quantizers available to discrete tokenizers; each member's value is the quantizer class."""

    VQ = VectorQuantizer
    LFQ = LFQuantizer
    FSQ = FSQuantizer
    RESFSQ = ResidualFSQuantizer
cosmos_transfer1/auxiliary/tokenizer/modules/distributions.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """The distribution modes to use for continuous image tokenizers."""
17
+
18
+ import torch
19
+
20
+
21
class IdentityDistribution(torch.nn.Module):
    """Pass-through "distribution": returns the input unchanged.

    The second return value mirrors the ``(mean, logvar)`` pair produced by
    ``GaussianDistribution`` with two single-element zero tensors.
    """

    def forward(self, parameters):
        """Return ``parameters`` itself plus a pair of zero placeholders."""
        zero_mean = torch.tensor([0.0])
        zero_logvar = torch.tensor([0.0])
        return parameters, (zero_mean, zero_logvar)
27
+
28
+
29
class GaussianDistribution(torch.nn.Module):
    """Diagonal Gaussian latent: splits parameters into mean / log-variance and samples."""

    def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
        """Store the range the log-variance is clamped to before sampling."""
        super().__init__()
        self.min_logvar = min_logvar
        self.max_logvar = max_logvar

    def sample(self, mean, logvar):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        noise = torch.randn_like(mean)
        return mean + torch.exp(0.5 * logvar) * noise

    def forward(self, parameters):
        """Split ``parameters`` in half along dim 1 and return ``(sample, (mean, logvar))``.

        The first half is the mean, the second half the log-variance, which is
        clamped to ``[min_logvar, max_logvar]`` before use.
        """
        mean, logvar = parameters.chunk(2, dim=1)
        logvar = logvar.clamp(self.min_logvar, self.max_logvar)
        return self.sample(mean, logvar), (mean, logvar)
cosmos_transfer1/auxiliary/tokenizer/modules/layers2d.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """The model definition for Continuous 2D layers
17
+
18
+ Adapted from: https://github.com/CompVis/stable-diffusion/blob/
19
+ 21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py
20
+
21
+ [Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors]
22
+ https://github.com/CompVis/stable-diffusion/blob/
23
+ 21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/LICENSE
24
+ """
25
+
26
+ import math
27
+
28
+ import numpy as np
29
+
30
+ # pytorch_diffusion + derived encoder decoder
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+ from loguru import logger as logging
35
+
36
+ from cosmos_transfer1.auxiliary.tokenizer.modules.patching import Patcher, UnPatcher
37
+ from cosmos_transfer1.auxiliary.tokenizer.modules.utils import Normalize, nonlinearity
38
+
39
+
40
class Upsample(nn.Module):
    """Double dims 2 and 3 of a feature map by repetition, then apply a 3x3 convolution."""

    def __init__(self, in_channels: int):
        """Create the channel-preserving 3x3 convolution applied after upsampling."""
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat every element twice along dim 2 and dim 3, then convolve."""
        rows_doubled = torch.repeat_interleave(x, 2, dim=2)
        upsampled = torch.repeat_interleave(rows_doubled, 2, dim=3)
        return self.conv(upsampled)
48
+
49
+
50
class Downsample(nn.Module):
    """Halve dims 2 and 3 of a feature map with an asymmetrically padded stride-2 3x3 convolution."""

    def __init__(self, in_channels: int):
        """Create the channel-preserving stride-2 convolution (padding is applied in ``forward``)."""
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Zero-pad one element at the end of the last two dims, then convolve."""
        # F.pad lists pads from the last dim backwards: (before, after) for
        # dim -1, then (before, after) for dim -2.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
59
+
60
+
61
class ResnetBlock(nn.Module):
    """Pre-activation residual block: two (norm -> nonlinearity -> 3x3 conv) stages plus a shortcut."""

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int = None,
        dropout: float,
        **kwargs,
    ):
        """Build the block.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels; defaults to ``in_channels``.
            dropout: Dropout probability applied before the second convolution.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.in_channels = in_channels
        if out_channels is None:
            out_channels = in_channels

        self.norm1 = Normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # A 1x1 projection is only needed when the channel count changes.
        if in_channels != out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.nin_shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``shortcut(x) + residual(x)``."""
        residual = self.conv1(nonlinearity(self.norm1(x)))
        residual = self.conv2(self.dropout(nonlinearity(self.norm2(residual))))
        return self.nin_shortcut(x) + residual
99
+
100
+
101
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions of a 2D feature map, with a residual add."""

    def __init__(self, in_channels: int):
        """Create the normalization and the 1x1 query / key / value / output projections."""
        super().__init__()

        self.norm = Normalize(in_channels)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend every spatial position to every other and add the result to ``x``."""
        # TODO (freda): Consider reusing implementations in Attn `imaginaire`,
        # since that one is going to be based on TransformerEngine's attn op,
        # which could ease CP implementations.
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        # Flatten the spatial grid into a token axis of length h * w.
        batch, channels, height, width = query.shape
        num_tokens = height * width
        query = query.reshape(batch, channels, num_tokens).permute(0, 2, 1)  # (b, hw, c)
        key = key.reshape(batch, channels, num_tokens)  # (b, c, hw)
        scores = torch.bmm(query, key)  # (b, hw_query, hw_key)
        scores = scores * (int(channels) ** (-0.5))
        weights = F.softmax(scores, dim=2)

        # Weighted sum of values for each query position.
        value = value.reshape(batch, channels, num_tokens)
        attended = torch.bmm(value, weights.permute(0, 2, 1))  # (b, c, hw_query)
        attended = attended.reshape(batch, channels, height, width)

        return x + self.proj_out(attended)
139
+
140
+
141
class Encoder(nn.Module):
    """2D convolutional encoder: patchify, residual/attention levels with downsampling, then project to latents."""

    def __init__(
        self,
        in_channels: int,
        channels: int,
        channels_mult: list[int],
        num_res_blocks: int,
        attn_resolutions: list[int],
        dropout: float,
        resolution: int,
        z_channels: int,
        spatial_compression: int,
        **ignore_kwargs,
    ):
        """Build the encoder.

        Args:
            in_channels: Channels of the input before patching.
            channels: Base channel width.
            channels_mult: Per-level multipliers of ``channels``; its length sets the number of levels.
            num_res_blocks: Residual blocks per level.
            attn_resolutions: Resolutions at which each residual block is followed by attention.
            dropout: Dropout probability for the residual blocks.
            resolution: Input resolution, used only to track where attention applies.
            z_channels: Channels of the output latent.
            spatial_compression: Total spatial compression factor (patching included).
            **ignore_kwargs: ``patch_size`` (default 1) and ``patch_method``
                (default "rearrange") are read from here; anything else is ignored.
        """
        super().__init__()
        self.num_resolutions = len(channels_mult)
        self.num_res_blocks = num_res_blocks

        # Patcher.
        patch_size = ignore_kwargs.get("patch_size", 1)
        self.patcher = Patcher(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
        in_channels = in_channels * patch_size * patch_size

        # calculate the number of downsample operations
        self.num_downsamples = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
        assert (
            self.num_downsamples <= self.num_resolutions
        ), f"we can only downsample {self.num_resolutions} times at most"

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, channels, kernel_size=3, stride=1, padding=1)

        curr_res = resolution // patch_size
        in_ch_mult = (1,) + tuple(channels_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = channels * in_ch_mult[i_level]
            block_out = channels * channels_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            # Only the first `num_downsamples` levels reduce the resolution.
            if i_level < self.num_downsamples:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into a ``z_channels``-channel latent map."""
        x = self.patcher(x)

        # downsampling
        # Keep a single running activation: nothing below reads earlier
        # feature maps, so there is no reason to hold on to all of them.
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
                # `attn` is either empty or holds one block per residual block.
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level < self.num_downsamples:
                h = level.downsample(h)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
236
+
237
+
238
class Decoder(nn.Module):
    """2D convolutional decoder: latent projection, residual/attention levels with upsampling, then unpatchify."""

    def __init__(
        self,
        out_channels: int,
        channels: int,
        channels_mult: list[int],
        num_res_blocks: int,
        attn_resolutions: list[int],
        dropout: float,
        resolution: int,
        z_channels: int,
        spatial_compression: int,
        **ignore_kwargs,
    ):
        """Build the decoder.

        Args:
            out_channels: Channels of the output after unpatching.
            channels: Base channel width.
            channels_mult: Per-level multipliers of ``channels``; its length sets the number of levels.
            num_res_blocks: Each level holds ``num_res_blocks + 1`` residual blocks.
            attn_resolutions: Resolutions at which each residual block is followed
                by attention (tested with ``in``, so it must be a collection).
            dropout: Dropout probability for the residual blocks.
            resolution: Output resolution, used only to track where attention applies.
            z_channels: Channels of the input latent.
            spatial_compression: Total spatial compression factor (patching included).
            **ignore_kwargs: ``patch_size`` (default 1) and ``patch_method``
                (default "rearrange") are read from here; anything else is ignored.
        """
        super().__init__()
        self.num_resolutions = len(channels_mult)
        self.num_res_blocks = num_res_blocks

        # UnPatcher.
        patch_size = ignore_kwargs.get("patch_size", 1)
        self.unpatcher = UnPatcher(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
        out_ch = out_channels * patch_size * patch_size

        # calculate the number of upsample operations
        self.num_upsamples = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
        assert self.num_upsamples <= self.num_resolutions, f"we can only upsample {self.num_resolutions} times at most"

        block_in = channels * channels_mult[self.num_resolutions - 1]
        curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = channels * channels_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            # Only the last `num_upsamples` levels (highest indices) change the resolution.
            if i_level >= (self.num_resolutions - self.num_upsamples):
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            # Levels are built from the last to the first; inserting at the
            # front keeps `self.up[i]` aligned with level index `i`.
            self.up.insert(0, up)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode the latent ``z`` back to the output space."""
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            level = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                h = level.block[i_block](h)
                # `attn` is either empty or holds one block per residual block.
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level >= (self.num_resolutions - self.num_upsamples):
                h = level.upsample(h)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        h = self.unpatcher(h)
        return h
cosmos_transfer1/auxiliary/tokenizer/modules/layers3d.py ADDED
@@ -0,0 +1,969 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """The model definition for 3D layers
17
+
18
+ Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/
19
+ 9f49074179c912736e617d61b32be367eb5f993a/magvit2_pytorch/magvit2_pytorch.py#L889
20
+
21
+ [MIT License Copyright (c) 2023 Phil Wang]
22
+ https://github.com/lucidrains/magvit2-pytorch/blob/
23
+ 9f49074179c912736e617d61b32be367eb5f993a/LICENSE
24
+ """
25
+ import math
26
+ from typing import Tuple, Union
27
+
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ from loguru import logger as logging
33
+
34
+ from cosmos_transfer1.auxiliary.tokenizer.modules.patching import Patcher, Patcher3D, UnPatcher, UnPatcher3D
35
+ from cosmos_transfer1.auxiliary.tokenizer.modules.utils import (
36
+ CausalNormalize,
37
+ batch2space,
38
+ batch2time,
39
+ cast_tuple,
40
+ is_odd,
41
+ nonlinearity,
42
+ replication_pad,
43
+ space2batch,
44
+ time2batch,
45
+ )
46
+
47
+ _LEGACY_NUM_GROUPS = 32
48
+
49
+
50
class CausalConv3d(nn.Module):
    """3D convolution whose temporal padding is applied only in front of the clip.

    The time axis is padded by repeating the first frame, so an output frame
    never depends on later input frames; the spatial axes are padded with
    ``pad_mode``.
    """

    def __init__(
        self,
        chan_in: int = 1,
        chan_out: int = 1,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        pad_mode: str = "constant",
        **kwargs,
    ):
        """Build the convolution.

        Args:
            chan_in: Input channels.
            chan_out: Output channels.
            kernel_size: One int or a (time, height, width) triple.
            pad_mode: ``F.pad`` mode used for the spatial padding.
            **kwargs: ``dilation``, ``stride``, ``time_stride``, ``time_dilation``
                (each default 1) and ``padding`` (default 1) are consumed here;
                anything left over is forwarded to ``nn.Conv3d``.
        """
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)
        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        spatial_dilation = kwargs.pop("dilation", 1)
        spatial_stride = kwargs.pop("stride", 1)
        time_stride = kwargs.pop("time_stride", 1)
        time_dilation = kwargs.pop("time_dilation", 1)
        padding = kwargs.pop("padding", 1)

        self.pad_mode = pad_mode
        # Number of leading frames to prepend.
        self.time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride)
        # Same padding on both sides of width and height (F.pad order).
        self.spatial_pad = (padding, padding, padding, padding)

        self.conv3d = nn.Conv3d(
            chan_in,
            chan_out,
            kernel_size,
            stride=(time_stride, spatial_stride, spatial_stride),
            dilation=(time_dilation, spatial_dilation, spatial_dilation),
            **kwargs,
        )

    def _replication_pad(self, x: torch.Tensor) -> torch.Tensor:
        """Prepend ``time_pad`` copies of the first frame, then pad height/width."""
        first_frame = x[:, :, :1, ...]
        leading = first_frame.repeat(1, 1, self.time_pad, 1, 1)
        x = torch.cat([leading, x], dim=2)
        # The trailing (0, 0) leaves the time axis untouched in F.pad.
        return F.pad(x, self.spatial_pad + (0, 0), mode=self.pad_mode, value=0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pad causally and convolve."""
        return self.conv3d(self._replication_pad(x))
98
+
99
+
100
class CausalUpsample3d(nn.Module):
    """Upsample a video feature map 2x in height/width and, for multi-frame inputs, in time."""

    def __init__(self, in_channels: int) -> None:
        """Create the channel-preserving causal 3x3x3 convolution applied after repetition."""
        super().__init__()
        self.conv = CausalConv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat along dims 3/4 (and dim 2 when it has more than one frame), convolve, trim."""
        x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4)

        # 2.0 when there is more than one frame, otherwise 1.0.
        time_factor = 1.0 + 1.0 * (x.shape[2] > 1)
        # NOTE(review): presumably guards against the shape comparison yielding
        # a tensor (e.g. under tracing) -- confirm.
        if isinstance(time_factor, torch.Tensor):
            time_factor = time_factor.item()
        num_repeats = int(time_factor)
        num_dropped = int(time_factor - 1)

        x = x.repeat_interleave(num_repeats, dim=2)
        # TODO(freda): Check if this causes temporal inconsistency.
        #              Should reverse the order of the following two ops,
        #              better perf and better temporal smoothness.
        x = self.conv(x)
        # Drop the leading frame(s) that temporal repetition duplicated.
        return x[..., num_dropped:, :, :]
116
+
117
+
118
class CausalDownsample3d(nn.Module):
    """Downsample a video feature map with a stride-2 (time and space) causal 3x3x3 convolution."""

    def __init__(self, in_channels: int) -> None:
        """Create the strided convolution; spatial padding is applied in ``forward``."""
        super().__init__()
        self.conv = CausalConv3d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=2,
            time_stride=2,
            padding=0,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Zero-pad the end of width/height, apply ``replication_pad``, then convolve."""
        # F.pad order: (w_before, w_after, h_before, h_after, t_before, t_after).
        spatially_padded = F.pad(x, (0, 1, 0, 1, 0, 0), mode="constant", value=0)
        # NOTE(review): `replication_pad` is defined in the tokenizer utils;
        # presumably it prepends a copy of the first frame -- confirm there.
        padded = replication_pad(spatially_padded)
        return self.conv(padded)
136
+
137
+
138
class CausalHybridUpsample3d(nn.Module):
    """Upsample in time and/or space by repetition, refining each step with a residual convolution."""

    def __init__(
        self,
        in_channels: int,
        spatial_up: bool = True,
        temporal_up: bool = True,
        **kwargs,
    ) -> None:
        """Build the layer.

        Args:
            in_channels: Channels of the input (and output).
            spatial_up: Whether to double height and width.
            temporal_up: Whether to upsample along time.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        # Temporal refinement (3x1x1).
        self.conv1 = CausalConv3d(
            in_channels,
            in_channels,
            kernel_size=(3, 1, 1),
            stride=1,
            time_stride=1,
            padding=0,
        )
        # Spatial refinement (1x3x3).
        self.conv2 = CausalConv3d(
            in_channels,
            in_channels,
            kernel_size=(1, 3, 3),
            stride=1,
            time_stride=1,
            padding=1,
        )
        # Final pointwise mixing (1x1x1).
        self.conv3 = CausalConv3d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            time_stride=1,
            padding=0,
        )
        self.spatial_up = spatial_up
        self.temporal_up = temporal_up

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the enabled upsampling steps; return ``x`` untouched when both are disabled."""
        if not (self.spatial_up or self.temporal_up):
            return x

        # hybrid upsample temporally.
        if self.temporal_up:
            # 2.0 when there is more than one frame, otherwise 1.0.
            time_factor = 1.0 + 1.0 * (x.shape[2] > 1)
            if isinstance(time_factor, torch.Tensor):
                time_factor = time_factor.item()
            repeated = x.repeat_interleave(int(time_factor), dim=2)
            # Drop the leading frame(s) duplicated by the repetition.
            x = repeated[..., int(time_factor - 1) :, :, :]
            x = self.conv1(x) + x

        # hybrid upsample spatially.
        if self.spatial_up:
            x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4)
            x = self.conv2(x) + x

        # final 1x1x1 conv.
        return self.conv3(x)
195
+
196
+
197
class CausalHybridDownsample3d(nn.Module):
    """Downsample in space and/or time by summing a strided convolution with average pooling."""

    def __init__(
        self,
        in_channels: int,
        spatial_down: bool = True,
        temporal_down: bool = True,
        **kwargs,
    ) -> None:
        """Build the layer.

        Args:
            in_channels: Channels of the input (and output).
            spatial_down: Whether to halve height and width.
            temporal_down: Whether to downsample along time.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        # Spatial stride-2 convolution (1x3x3).
        self.conv1 = CausalConv3d(
            in_channels,
            in_channels,
            kernel_size=(1, 3, 3),
            stride=2,
            time_stride=1,
            padding=0,
        )
        # Temporal stride-2 convolution (3x1x1).
        self.conv2 = CausalConv3d(
            in_channels,
            in_channels,
            kernel_size=(3, 1, 1),
            stride=1,
            time_stride=2,
            padding=0,
        )
        # Final pointwise mixing (1x1x1).
        self.conv3 = CausalConv3d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            time_stride=1,
            padding=0,
        )
        self.spatial_down = spatial_down
        self.temporal_down = temporal_down

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the enabled downsampling steps; return ``x`` untouched when both are disabled."""
        if not (self.spatial_down or self.temporal_down):
            return x

        # hybrid downsample spatially.
        if self.spatial_down:
            # Zero-pad one element at the end of width and height (F.pad order).
            x = F.pad(x, (0, 1, 0, 1, 0, 0), mode="constant", value=0)
            conv_branch = self.conv1(x)
            pool_branch = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            x = conv_branch + pool_branch

        # hybrid downsample temporally.
        if self.temporal_down:
            # NOTE(review): `replication_pad` comes from the tokenizer utils;
            # presumably it prepends a copy of the first frame -- confirm there.
            x = replication_pad(x)
            conv_branch = self.conv2(x)
            pool_branch = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1))
            x = conv_branch + pool_branch

        # final 1x1x1 conv.
        return self.conv3(x)
255
+
256
+
257
class CausalResnetBlock3d(nn.Module):
    """Pre-activation residual block built from causal 3x3x3 convolutions."""

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int = None,
        dropout: float,
        num_groups: int,
    ) -> None:
        """Build the block.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels; defaults to ``in_channels``.
            dropout: Dropout probability applied before the second convolution.
            num_groups: Group count passed to both ``CausalNormalize`` layers.
        """
        super().__init__()
        self.in_channels = in_channels
        if out_channels is None:
            out_channels = in_channels

        self.norm1 = CausalNormalize(in_channels, num_groups=num_groups)
        self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = CausalNormalize(out_channels, num_groups=num_groups)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # A 1x1x1 projection is only needed when the channel count changes.
        if in_channels != out_channels:
            self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.nin_shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``shortcut(x) + residual(x)``."""
        residual = self.conv1(nonlinearity(self.norm1(x)))
        residual = self.conv2(self.dropout(nonlinearity(self.norm2(residual))))
        return self.nin_shortcut(x) + residual
294
+
295
+
296
class CausalResnetBlockFactorized3d(nn.Module):
    """Residual block whose convolutions are factorized into (1,3,3) + (3,1,1).

    Same layout as a plain pre-activation residual block (normalize ->
    nonlinearity -> conv, twice, with dropout before the second conv), except
    that each convolution is an ``nn.Sequential`` of a ``(1, 3, 3)`` kernel
    followed by a ``(3, 1, 1)`` kernel.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count of the output; ``None`` keeps ``in_channels``.
        dropout: Probability passed to ``torch.nn.Dropout``.
        num_groups: Forwarded to the second ``CausalNormalize`` only (see note
            in ``__init__``).
    """

    @staticmethod
    def _factorized_conv(c_in: int, c_out: int) -> nn.Sequential:
        """Build the (1,3,3) conv followed by the (3,1,1) conv."""
        return nn.Sequential(
            CausalConv3d(
                c_in,
                c_out,
                kernel_size=(1, 3, 3),
                stride=1,
                padding=1,
            ),
            CausalConv3d(
                c_out,
                c_out,
                kernel_size=(3, 1, 1),
                stride=1,
                padding=0,
            ),
        )

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int = None,
        dropout: float,
        num_groups: int,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        if out_channels is None:
            out_channels = in_channels

        # NOTE(review): norm1 hard-codes num_groups=1 while norm2 honours the
        # argument. Left untouched because changing it would alter the
        # numerics of already-trained weights -- confirm whether intentional.
        self.norm1 = CausalNormalize(in_channels, num_groups=1)
        self.conv1 = self._factorized_conv(in_channels, out_channels)
        self.norm2 = CausalNormalize(out_channels, num_groups=num_groups)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = self._factorized_conv(out_channels, out_channels)

        # A projection on the skip path is only needed when the widths differ.
        if in_channels != out_channels:
            self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.nin_shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``shortcut(x) + residual(x)``."""
        residual = self.conv1(nonlinearity(self.norm1(x)))
        residual = self.conv2(self.dropout(nonlinearity(self.norm2(residual))))
        return self.nin_shortcut(x) + residual
363
+
364
+
365
class CausalAttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of each frame.

    Queries, keys and values come from ``kernel_size=1`` ``CausalConv3d``
    projections of the normalized input. ``time2batch`` / ``batch2time``
    (defined elsewhere in this file) are used so that attention is computed on
    4-D ``(b, c, h, w)`` tensors; presumably they fold the time axis into the
    batch axis and back -- TODO confirm. The result is added to the input.

    Args:
        in_channels: Channel count of the input (and output) tensor.
        num_groups: Forwarded to ``CausalNormalize``.
    """

    def __init__(self, in_channels: int, num_groups: int) -> None:
        super().__init__()

        self.norm = CausalNormalize(in_channels, num_groups=num_groups)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the projected spatial attention output."""
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        # Move to the 4-D layout in which attention is computed.
        query, batch_size = time2batch(query)
        key, batch_size = time2batch(key)
        value, batch_size = time2batch(value)

        b, c, h, w = query.shape
        positions = h * w

        # scores[b, i, j] = <q_i, k_j> / sqrt(c), softmax over j.
        query = query.reshape(b, c, positions).permute(0, 2, 1)
        key = key.reshape(b, c, positions)
        scores = torch.bmm(query, key)
        scores = scores * (int(c) ** (-0.5))
        scores = F.softmax(scores, dim=2)

        # Weighted sum of the values: out[b, :, i] = sum_j v[b, :, j] * scores[b, i, j].
        value = value.reshape(b, c, positions)
        attended = torch.bmm(value, scores.permute(0, 2, 1))
        attended = attended.reshape(b, c, h, w)

        attended = batch2time(attended, batch_size)
        attended = self.proj_out(attended)
        return x + attended
404
+
405
+
406
class CausalTemporalAttnBlock(nn.Module):
    """Single-head causal self-attention along the temporal axis.

    Queries, keys and values come from ``kernel_size=1`` ``CausalConv3d``
    projections of the normalized input. ``space2batch`` / ``batch2space``
    (defined elsewhere in this file) are used so that attention is computed on
    3-D ``(bhw, c, t)`` tensors; presumably they fold the spatial axes into
    the batch axis and back -- TODO confirm. A lower-triangular mask keeps
    each time step from attending to later steps. The result is added to the
    input.

    Args:
        in_channels: Channel count of the input (and output) tensor.
        num_groups: Forwarded to ``CausalNormalize``.
    """

    def __init__(self, in_channels: int, num_groups: int) -> None:
        super().__init__()

        self.norm = CausalNormalize(in_channels, num_groups=num_groups)
        self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the projected causal temporal attention output."""
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        q, batch_size, height = space2batch(q)
        k, _, _ = space2batch(k)
        v, _, _ = space2batch(v)

        bhw, c, t = q.shape
        q = q.permute(0, 2, 1)  # (bhw, t, c)
        k = k.permute(0, 2, 1)  # (bhw, t, c)
        v = v.permute(0, 2, 1)  # (bhw, t, c)

        w_ = torch.bmm(q, k.permute(0, 2, 1))  # (bhw, t, t)
        w_ = w_ * (int(c) ** (-0.5))

        # Apply causal mask. The mask is the same for every batch row, so it
        # is built once with shape (1, t, t) and broadcast by masked_fill
        # instead of materializing a (bhw, t, t) ones tensor and comparison.
        # Slicing with [:1] (rather than [0]) keeps an empty batch valid.
        future = torch.tril(torch.ones_like(w_[:1])) == 0
        w_ = w_.masked_fill(future, float("-inf"))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        h_ = torch.bmm(w_, v)  # (bhw, t, c)
        h_ = h_.permute(0, 2, 1).reshape(bhw, c, t)  # (bhw, c, t)

        h_ = batch2space(h_, batch_size, height)
        h_ = self.proj_out(h_)
        return x + h_
448
+
449
+
450
class EncoderBase(nn.Module):
    """Causal 3D encoder: 2D patching, residual/attention levels, then a mid block.

    Each resolution level holds ``num_res_blocks`` ``CausalResnetBlock3d``
    modules (each optionally followed by a ``CausalAttnBlock`` when the current
    resolution is listed in ``attn_resolutions``). Every level but the last
    ends with a ``CausalDownsample3d``; the last level is followed by a
    temporal average pooling in ``forward``.

    Args:
        in_channels: Channel count of the raw input (before patching).
        channels: Base channel width.
        channels_mult: Per-level multiplier applied to ``channels``.
        num_res_blocks: Residual blocks per level.
        attn_resolutions: Resolutions at which attention blocks are inserted.
        dropout: Dropout probability for the residual blocks.
        resolution: Input spatial resolution, used only to track ``curr_res``.
        z_channels: Channel count of the encoder output.
        **ignore_kwargs: Optional ``patch_size`` (default 1), ``patch_method``
            (default ``"rearrange"``) and ``num_groups`` (default
            ``_LEGACY_NUM_GROUPS``); other keys are ignored.
    """

    def __init__(
        self,
        in_channels: int,
        channels: int,
        channels_mult: list[int],
        num_res_blocks: int,
        attn_resolutions: list[int],
        dropout: float,
        resolution: int,
        z_channels: int,
        **ignore_kwargs,
    ) -> None:
        super().__init__()
        self.num_resolutions = len(channels_mult)
        self.num_res_blocks = num_res_blocks

        # Patcher.
        patch_size = ignore_kwargs.get("patch_size", 1)
        self.patcher = Patcher(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
        in_channels = in_channels * patch_size * patch_size

        # downsampling
        self.conv_in = CausalConv3d(in_channels, channels, kernel_size=3, stride=1, padding=1)

        # num of groups for GroupNorm, num_groups=1 for LayerNorm.
        num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS)
        curr_res = resolution // patch_size
        in_ch_mult = (1,) + tuple(channels_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = channels * in_ch_mult[i_level]
            block_out = channels * channels_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    CausalResnetBlock3d(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                        num_groups=num_groups,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(CausalAttnBlock(block_in, num_groups=num_groups))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = CausalDownsample3d(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = CausalResnetBlock3d(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
            num_groups=num_groups,
        )
        self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups)
        self.mid.block_2 = CausalResnetBlock3d(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
            num_groups=num_groups,
        )

        # end
        self.norm_out = CausalNormalize(block_in, num_groups=num_groups)
        self.conv_out = CausalConv3d(block_in, z_channels, kernel_size=3, stride=1, padding=1)

    def patcher3d(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the 2D patcher between ``time2batch`` and ``batch2time``."""
        x, batch_size = time2batch(x)
        x = self.patcher(x)
        x = batch2time(x, batch_size)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into a ``z_channels``-wide latent tensor.

        Only the most recent activation is kept alive while walking the
        levels. There are no skip connections, so earlier activations are
        never read again and holding on to them would only raise peak memory.
        """
        x = self.patcher3d(x)

        # downsampling
        h = self.conv_in(x)
        last_level = self.num_resolutions - 1
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != last_level:
                h = level.downsample(h)
            else:
                # temporal downsample (last level): pool over 2 frames unless
                # there is only a single frame, in which case the kernel is 1.
                time_factor = 1 + 1 * (h.shape[2] > 1)
                if isinstance(time_factor, torch.Tensor):
                    # Only a tensor when the shape comparison yields one.
                    time_factor = time_factor.item()
                h = replication_pad(h)
                h = F.avg_pool3d(
                    h,
                    kernel_size=[time_factor, 1, 1],
                    stride=[2, 1, 1],
                )

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
570
+
571
+
572
class DecoderBase(nn.Module):
    """Causal 3D decoder: mid block, residual/attention levels, then unpatching.

    Levels are built from the deepest to the shallowest. Each one holds
    ``num_res_blocks + 1`` ``CausalResnetBlock3d`` modules (each optionally
    followed by a ``CausalAttnBlock`` when the current resolution is listed in
    ``attn_resolutions``). Every level except level 0 ends with a
    ``CausalUpsample3d``; after level 0 ``forward`` repeats frames along the
    time axis instead.

    Args:
        out_channels: Channel count of the final output (after unpatching).
        channels: Base channel width.
        channels_mult: Per-level multiplier applied to ``channels``.
        num_res_blocks: Residual blocks per level (one extra is added).
        attn_resolutions: Resolutions at which attention blocks are inserted.
        dropout: Dropout probability for the residual blocks.
        resolution: Output spatial resolution, used to derive ``curr_res``.
        z_channels: Channel count of the latent input.
        **ignore_kwargs: Optional ``patch_size`` (default 1), ``patch_method``
            (default ``"rearrange"``) and ``num_groups`` (default
            ``_LEGACY_NUM_GROUPS``); other keys are ignored.
    """

    def __init__(
        self,
        out_channels: int,
        channels: int,
        channels_mult: list[int],
        num_res_blocks: int,
        attn_resolutions: list[int],
        dropout: float,
        resolution: int,
        z_channels: int,
        **ignore_kwargs,
    ):
        super().__init__()
        self.num_resolutions = len(channels_mult)
        self.num_res_blocks = num_res_blocks

        # UnPatcher.
        patch_size = ignore_kwargs.get("patch_size", 1)
        patch_method = ignore_kwargs.get("patch_method", "rearrange")
        self.unpatcher = UnPatcher(patch_size, patch_method)
        out_ch = out_channels * patch_size * patch_size

        deepest = self.num_resolutions - 1
        block_in = channels * channels_mult[deepest]
        curr_res = (resolution // patch_size) // 2**deepest
        self.z_shape = (1, z_channels, curr_res, curr_res)
        logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

        # Latent -> widest feature map.
        self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # Group count for GroupNorm; num_groups=1 gives LayerNorm behaviour.
        num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS)

        def make_resblock(c_in: int, c_out: int) -> nn.Module:
            return CausalResnetBlock3d(
                in_channels=c_in,
                out_channels=c_out,
                dropout=dropout,
                num_groups=num_groups,
            )

        # Middle.
        self.mid = nn.Module()
        self.mid.block_1 = make_resblock(block_in, block_in)
        self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups)
        self.mid.block_2 = make_resblock(block_in, block_in)

        # Upsampling levels, built deepest-first.
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_out = channels * channels_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_blocks.append(make_resblock(block_in, block_out))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn_blocks.append(CausalAttnBlock(block_in, num_groups=num_groups))
            up = nn.Module()
            up.block = res_blocks
            up.attn = attn_blocks
            if i_level != 0:
                up.upsample = CausalUpsample3d(block_in)
                curr_res = curr_res * 2
            # Prepend so that self.up[i] corresponds to level i.
            self.up.insert(0, up)

        # Output head.
        self.norm_out = CausalNormalize(block_in, num_groups=num_groups)
        self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def unpatcher3d(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the 2D unpatcher between ``time2batch`` and ``batch2time``."""
        folded, batch_size = time2batch(x)
        folded = self.unpatcher(folded)
        return batch2time(folded, batch_size)

    def forward(self, z):
        """Decode the latent ``z`` back to the (unpatched) output tensor."""
        h = self.conv_in(z)

        # Middle block.
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # Decoder levels, deepest first.
        for i_level in reversed(range(self.num_resolutions)):
            level = self.up[i_level]
            has_attn = len(level.attn) > 0
            for i_block in range(self.num_res_blocks + 1):
                h = level.block[i_block](h)
                if has_attn:
                    h = level.attn[i_block](h)
            if i_level != 0:
                h = level.upsample(h)
                continue

            # Temporal upsample (last level): repeat every frame twice and
            # drop the first copy, unless there is only a single frame.
            time_factor = 1.0 + 1.0 * (h.shape[2] > 1)
            if isinstance(time_factor, torch.Tensor):
                time_factor = time_factor.item()
            h = h.repeat_interleave(int(time_factor), dim=2)
            h = h[..., int(time_factor - 1) :, :, :]

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return self.unpatcher3d(h)
687
+
688
+
689
class EncoderFactorized(nn.Module):
    """Causal 3D encoder built from factorized (spatial + temporal) convolutions.

    The input is patched with ``Patcher3D``, then passed through
    ``num_resolutions`` levels of ``CausalResnetBlockFactorized3d`` blocks
    (each optionally followed by spatial + temporal attention when the current
    resolution is listed in ``attn_resolutions``). Every level but the last
    ends with a ``CausalHybridDownsample3d`` whose spatial/temporal flags are
    derived from the requested compression factors.

    Args:
        in_channels: Channel count of the raw input (before patching).
        channels: Base channel width.
        channels_mult: Per-level multiplier applied to ``channels``.
        num_res_blocks: Residual blocks per level.
        attn_resolutions: Resolutions at which attention blocks are inserted.
        dropout: Dropout probability for the residual blocks.
        resolution: Input spatial resolution, used only to track ``curr_res``.
        z_channels: Channel count of the encoder output.
        spatial_compression: Total spatial compression factor.
        temporal_compression: Total temporal compression factor.
        **ignore_kwargs: Optional ``patch_size`` (default 1) and
            ``patch_method`` (default ``"rearrange"``); other keys are ignored.
    """

    def __init__(
        self,
        in_channels: int,
        channels: int,
        channels_mult: list[int],
        num_res_blocks: int,
        attn_resolutions: list[int],
        dropout: float,
        resolution: int,
        z_channels: int,
        spatial_compression: int = 16,
        temporal_compression: int = 8,
        **ignore_kwargs,
    ) -> None:
        super().__init__()
        self.num_resolutions = len(channels_mult)
        self.num_res_blocks = num_res_blocks

        # Patcher.
        patch_size = ignore_kwargs.get("patch_size", 1)
        self.patcher3d = Patcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
        in_channels = in_channels * patch_size * patch_size * patch_size

        # calculate the number of downsample operations; the patcher already
        # accounts for log2(patch_size) of each compression factor.
        self.num_spatial_downs = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
        assert (
            self.num_spatial_downs <= self.num_resolutions
        ), f"Spatially downsample {self.num_resolutions} times at most"

        self.num_temporal_downs = int(math.log2(temporal_compression)) - int(math.log2(patch_size))
        assert (
            self.num_temporal_downs <= self.num_resolutions
        ), f"Temporally downsample {self.num_resolutions} times at most"

        # downsampling
        self.conv_in = nn.Sequential(
            CausalConv3d(
                in_channels,
                channels,
                kernel_size=(1, 3, 3),
                stride=1,
                padding=1,
            ),
            CausalConv3d(channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0),
        )

        curr_res = resolution // patch_size
        in_ch_mult = (1,) + tuple(channels_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = channels * in_ch_mult[i_level]
            block_out = channels * channels_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    CausalResnetBlockFactorized3d(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                        num_groups=1,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(
                        nn.Sequential(
                            CausalAttnBlock(block_in, num_groups=1),
                            CausalTemporalAttnBlock(block_in, num_groups=1),
                        )
                    )
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                spatial_down = i_level < self.num_spatial_downs
                temporal_down = i_level < self.num_temporal_downs
                down.downsample = CausalHybridDownsample3d(
                    block_in,
                    spatial_down=spatial_down,
                    temporal_down=temporal_down,
                )
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = CausalResnetBlockFactorized3d(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
            num_groups=1,
        )
        self.mid.attn_1 = nn.Sequential(
            CausalAttnBlock(block_in, num_groups=1),
            CausalTemporalAttnBlock(block_in, num_groups=1),
        )
        self.mid.block_2 = CausalResnetBlockFactorized3d(
            in_channels=block_in,
            out_channels=block_in,
            dropout=dropout,
            num_groups=1,
        )

        # end
        self.norm_out = CausalNormalize(block_in, num_groups=1)
        self.conv_out = nn.Sequential(
            CausalConv3d(block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1),
            CausalConv3d(
                z_channels,
                z_channels,
                kernel_size=(3, 1, 1),
                stride=1,
                padding=0,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into a ``z_channels``-wide latent tensor.

        Only the most recent activation is kept alive while walking the
        levels. There are no skip connections, so earlier activations are
        never read again and holding on to them would only raise peak memory.
        """
        x = self.patcher3d(x)

        # downsampling
        h = self.conv_in(x)
        last_level = self.num_resolutions - 1
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != last_level:
                h = level.downsample(h)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
833
+
834
+
835
class DecoderFactorized(nn.Module):
    """Causal 3D decoder built from factorized (spatial + temporal) convolutions.

    Mirrors the factorized encoder: a mid block, then ``num_resolutions``
    levels of ``num_res_blocks + 1`` ``CausalResnetBlockFactorized3d`` blocks
    (each optionally followed by spatial + temporal attention when the current
    resolution is listed in ``attn_resolutions``). Every level except level 0
    ends with a ``CausalHybridUpsample3d``. The output is finally unpatched
    with ``UnPatcher3D``.

    Args:
        out_channels: Channel count of the final output (after unpatching).
        channels: Base channel width.
        channels_mult: Per-level multiplier applied to ``channels``.
        num_res_blocks: Residual blocks per level (one extra is added).
        attn_resolutions: Resolutions at which attention blocks are inserted.
        dropout: Dropout probability for the residual blocks.
        resolution: Output spatial resolution, used to derive ``curr_res``.
        z_channels: Channel count of the latent input.
        spatial_compression: Total spatial compression factor.
        temporal_compression: Total temporal compression factor.
        **ignore_kwargs: Optional ``patch_size`` (default 1), ``patch_method``
            (default ``"rearrange"``) and ``legacy_mode`` (default ``False``);
            other keys are ignored.
    """

    def __init__(
        self,
        out_channels: int,
        channels: int,
        channels_mult: list[int],
        num_res_blocks: int,
        attn_resolutions: list[int],
        dropout: float,
        resolution: int,
        z_channels: int,
        spatial_compression: int = 16,
        temporal_compression: int = 8,
        **ignore_kwargs,
    ):
        super().__init__()
        self.num_resolutions = len(channels_mult)
        self.num_res_blocks = num_res_blocks

        # UnPatcher.
        patch_size = ignore_kwargs.get("patch_size", 1)
        patch_method = ignore_kwargs.get("patch_method", "rearrange")
        self.unpatcher3d = UnPatcher3D(patch_size, patch_method)
        out_ch = out_channels * patch_size * patch_size * patch_size

        # Number of upsample operations; the unpatcher already accounts for
        # log2(patch_size) of each compression factor.
        patch_levels = int(math.log2(patch_size))
        self.num_spatial_ups = int(math.log2(spatial_compression)) - patch_levels
        assert self.num_spatial_ups <= self.num_resolutions, f"Spatially upsample {self.num_resolutions} times at most"
        self.num_temporal_ups = int(math.log2(temporal_compression)) - patch_levels
        assert (
            self.num_temporal_ups <= self.num_resolutions
        ), f"Temporally upsample {self.num_resolutions} times at most"

        deepest = self.num_resolutions - 1
        block_in = channels * channels_mult[deepest]
        curr_res = (resolution // patch_size) // 2**deepest
        self.z_shape = (1, z_channels, curr_res, curr_res)
        logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

        # Latent -> widest feature map.
        self.conv_in = nn.Sequential(
            CausalConv3d(z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1),
            CausalConv3d(block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0),
        )

        def make_resblock(c_in: int, c_out: int) -> nn.Module:
            return CausalResnetBlockFactorized3d(
                in_channels=c_in,
                out_channels=c_out,
                dropout=dropout,
                num_groups=1,
            )

        def make_attn(width: int) -> nn.Sequential:
            return nn.Sequential(
                CausalAttnBlock(width, num_groups=1),
                CausalTemporalAttnBlock(width, num_groups=1),
            )

        # Middle.
        self.mid = nn.Module()
        self.mid.block_1 = make_resblock(block_in, block_in)
        self.mid.attn_1 = make_attn(block_in)
        self.mid.block_2 = make_resblock(block_in, block_in)

        legacy_mode = ignore_kwargs.get("legacy_mode", False)

        # Upsampling levels, built deepest-first.
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_out = channels * channels_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_blocks.append(make_resblock(block_in, block_out))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn_blocks.append(make_attn(block_in))
            up = nn.Module()
            up.block = res_blocks
            up.attn = attn_blocks
            if i_level != 0:
                # The layer index for temporal/spatial downsampling performed
                # in the encoder should correspond to the layer index in
                # reverse order where upsampling is performed in the decoder.
                # If you've a pre-trained model, you can simply finetune.
                i_level_reverse = self.num_resolutions - i_level - 1
                if legacy_mode:
                    temporal_up = i_level_reverse < self.num_temporal_ups
                else:
                    temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1
                more_spatial_than_temporal = self.num_spatial_ups > self.num_temporal_ups
                spatial_up = temporal_up or (i_level_reverse < self.num_spatial_ups and more_spatial_than_temporal)
                up.upsample = CausalHybridUpsample3d(block_in, spatial_up=spatial_up, temporal_up=temporal_up)
                curr_res = curr_res * 2
            # Prepend so that self.up[i] corresponds to level i.
            self.up.insert(0, up)

        # Output head.
        self.norm_out = CausalNormalize(block_in, num_groups=1)
        self.conv_out = nn.Sequential(
            CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1),
            CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0),
        )

    def forward(self, z):
        """Decode the latent ``z`` back to the (unpatched) output tensor."""
        h = self.conv_in(z)

        # Middle block.
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # Decoder levels, deepest first.
        for i_level in reversed(range(self.num_resolutions)):
            level = self.up[i_level]
            has_attn = len(level.attn) > 0
            for i_block in range(self.num_res_blocks + 1):
                h = level.block[i_block](h)
                if has_attn:
                    h = level.attn[i_block](h)
            if i_level != 0:
                h = level.upsample(h)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return self.unpatcher3d(h)
cosmos_transfer1/auxiliary/tokenizer/modules/patching.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """The patcher and unpatcher implementation for 2D and 3D data.
17
+
18
+ The idea of Haar wavelet is to compute LL, LH, HL, HH component as two 1D convolutions.
19
+ One on the rows and one on the columns.
20
+ For example, in a 1D signal, we have [a, b], then the low-freq component is [a + b] / 2 and high-freq is [a - b] / 2.
21
+ We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component.
22
+ For H component, we can use a 1D convolution with kernel [1, -1] and stride 2.
23
+ Although in principle, we typically only do additional Haar wavelet over the LL component. But here we do it for all
24
+ as we need to support downsampling for more than 2x.
25
+ For example, 4x downsampling can be done by a 2x Haar followed by an additional 2x Haar, and the shapes would be:
26
+ [3, 256, 256] -> [12, 128, 128] -> [48, 64, 64]
27
+ """
28
+
29
+ import torch
30
+ import torch.nn.functional as F
31
+ from einops import rearrange
32
+
33
# Two-tap filters keyed by patch method. Both entries are needed because the
# patcher modules register the selected filter as a buffer in either mode.
_WAVELETS = dict(
    haar=torch.tensor([0.7071067811865476] * 2),
    rearrange=torch.ones(2),
)
# Passed as ``persistent=`` to ``register_buffer``: False keeps the filter
# buffers out of ``state_dict``.
_PERSISTENT = False
38
+
39
+
40
class Patcher(torch.nn.Module):
    """Turn 4-D image tensors into patches using only torch operations.

    Two methods are supported: ``"haar"`` applies a 2x2 wavelet split
    ``log2(patch_size)`` times, ``"rearrange"`` moves each
    ``patch_size x patch_size`` tile into the channel axis. Either way the
    channel count grows by ``patch_size ** 2`` and both spatial sizes shrink
    by ``patch_size``.

    Unlike `class Patching`, everything here is expressed with torch ops
    (rather than python or numpy) for efficiency; the outputs are meant to be
    bit-wise identical to that module, with the added benefit of being
    torch.jit scriptable.
    """

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__()
        self.patch_size = patch_size
        self.patch_method = patch_method
        taps = _WAVELETS[patch_method]
        self.register_buffer("wavelets", taps, persistent=_PERSISTENT)
        # One wavelet level per factor of two in patch_size.
        num_levels = int(torch.log2(torch.tensor(self.patch_size)).item())
        self.range = range(num_levels)
        # Tap indices, used to alternate the sign of the high-pass filter.
        self.register_buffer(
            "_arange",
            torch.arange(taps.shape[0]),
            persistent=_PERSISTENT,
        )
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        if self.patch_method == "haar":
            return self._haar(x)
        if self.patch_method == "rearrange":
            return self._arrange(x)
        raise ValueError("Unknown patch method: " + self.patch_method)

    def _dwt(self, x, mode="reflect", rescale=False):
        """One level of the 2D wavelet split; returns [LL, LH, HL, HH] on channels."""
        dtype = x.dtype
        taps = self.wavelets

        n = taps.shape[0]
        g = x.shape[1]
        # Depthwise (one filter per channel) low-pass and high-pass kernels.
        lo = taps.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
        hi = (taps * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
        hi = hi.to(dtype=dtype)
        lo = lo.to(dtype=dtype)

        x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)

        # Filter along the last axis (stride 2 on width) ...
        lo_w = lo.unsqueeze(2)
        hi_w = hi.unsqueeze(2)
        xl = F.conv2d(x, lo_w, groups=g, stride=(1, 2))
        xh = F.conv2d(x, hi_w, groups=g, stride=(1, 2))

        # ... then along the second-to-last axis (stride 2 on height).
        lo_h = lo.unsqueeze(3)
        hi_h = hi.unsqueeze(3)
        xll = F.conv2d(xl, lo_h, groups=g, stride=(2, 1))
        xlh = F.conv2d(xl, hi_h, groups=g, stride=(2, 1))
        xhl = F.conv2d(xh, lo_h, groups=g, stride=(2, 1))
        xhh = F.conv2d(xh, hi_h, groups=g, stride=(2, 1))

        out = torch.cat([xll, xlh, xhl, xhh], dim=1)
        if rescale:
            out = out / 2
        return out

    def _haar(self, x):
        """Apply ``_dwt`` once per level in ``self.range``."""
        for _ in self.range:
            x = self._dwt(x, rescale=True)
        return x

    def _arrange(self, x):
        """Move each patch_size x patch_size tile into the channel axis."""
        patched = rearrange(
            x,
            "b c (h p1) (w p2) -> b (c p1 p2) h w",
            p1=self.patch_size,
            p2=self.patch_size,
        )
        return patched.contiguous()
109
+
110
+
111
class Patcher3D(Patcher):
    """3D patcher for video data; expects a 5-D tensor, i.e. a batch of videos.

    Before patching, the first frame is repeated ``patch_size`` times along
    the time axis. ``"haar"`` then applies a 2x2x2 wavelet split
    ``log2(patch_size)`` times; ``"rearrange"`` moves each
    ``patch_size``-cube into the channel axis.
    """

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__(patch_method=patch_method, patch_size=patch_size)
        # Not read anywhere in this class; kept as a registered buffer.
        self.register_buffer(
            "patch_size_buffer",
            patch_size * torch.ones([1], dtype=torch.int32),
            persistent=_PERSISTENT,
        )

    def _dwt(self, x, wavelet, mode="reflect", rescale=False):
        """One level of the 3D wavelet split; returns 8 sub-bands on channels.

        ``wavelet`` is accepted but not used: the filter always comes from
        ``self.wavelets``.
        """
        dtype = x.dtype
        taps = self.wavelets

        n = taps.shape[0]
        g = x.shape[1]
        # Depthwise (one filter per channel) low-pass and high-pass kernels.
        lo = taps.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
        hi = (taps * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
        hi = hi.to(dtype=dtype)
        lo = lo.to(dtype=dtype)

        x = F.pad(x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)

        # Temporal axis (dim 2), stride 2.
        lo_t = lo.unsqueeze(3).unsqueeze(4)
        hi_t = hi.unsqueeze(3).unsqueeze(4)
        xl = F.conv3d(x, lo_t, groups=g, stride=(2, 1, 1))
        xh = F.conv3d(x, hi_t, groups=g, stride=(2, 1, 1))

        # First spatial axis (dim 3), stride 2.
        lo_h = lo.unsqueeze(2).unsqueeze(4)
        hi_h = hi.unsqueeze(2).unsqueeze(4)
        xll = F.conv3d(xl, lo_h, groups=g, stride=(1, 2, 1))
        xlh = F.conv3d(xl, hi_h, groups=g, stride=(1, 2, 1))
        xhl = F.conv3d(xh, lo_h, groups=g, stride=(1, 2, 1))
        xhh = F.conv3d(xh, hi_h, groups=g, stride=(1, 2, 1))

        # Second spatial axis (dim 4), stride 2.
        lo_w = lo.unsqueeze(2).unsqueeze(3)
        hi_w = hi.unsqueeze(2).unsqueeze(3)
        xlll = F.conv3d(xll, lo_w, groups=g, stride=(1, 1, 2))
        xllh = F.conv3d(xll, hi_w, groups=g, stride=(1, 1, 2))
        xlhl = F.conv3d(xlh, lo_w, groups=g, stride=(1, 1, 2))
        xlhh = F.conv3d(xlh, hi_w, groups=g, stride=(1, 1, 2))
        xhll = F.conv3d(xhl, lo_w, groups=g, stride=(1, 1, 2))
        xhlh = F.conv3d(xhl, hi_w, groups=g, stride=(1, 1, 2))
        xhhl = F.conv3d(xhh, lo_w, groups=g, stride=(1, 1, 2))
        xhhh = F.conv3d(xhh, hi_w, groups=g, stride=(1, 1, 2))

        out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
        if rescale:
            out = out / (2 * torch.sqrt(torch.tensor(2.0)))
        return out

    def _haar(self, x):
        """Repeat the first frame, then apply ``_dwt`` once per level."""
        first, rest = torch.split(x, [1, x.shape[2] - 1], dim=2)
        x = torch.cat([first.repeat_interleave(self.patch_size, dim=2), rest], dim=2)
        for _ in self.range:
            x = self._dwt(x, "haar", rescale=True)
        return x

    def _arrange(self, x):
        """Repeat the first frame, then move each patch cube into channels."""
        first, rest = torch.split(x, [1, x.shape[2] - 1], dim=2)
        x = torch.cat([first.repeat_interleave(self.patch_size, dim=2), rest], dim=2)
        patched = rearrange(
            x,
            "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w",
            p1=self.patch_size,
            p2=self.patch_size,
            p3=self.patch_size,
        )
        return patched.contiguous()
176
+
177
+
178
class UnPatcher(torch.nn.Module):
    """Turn patches back into 4-D image tensors using only torch operations.

    Two methods are supported: ``"haar"`` applies the inverse 2x2 wavelet
    step ``log2(patch_size)`` times, ``"rearrange"`` moves the patch entries
    from the channel axis back to ``patch_size x patch_size`` tiles.

    Unlike `class Unpatching`, everything here is expressed with torch ops
    (rather than python or numpy) for efficiency; the outputs are meant to be
    bit-wise identical to that module, with the added benefit of being
    torch.jit scriptable.
    """

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__()
        self.patch_size = patch_size
        self.patch_method = patch_method
        taps = _WAVELETS[patch_method]
        self.register_buffer("wavelets", taps, persistent=_PERSISTENT)
        # One wavelet level per factor of two in patch_size.
        num_levels = int(torch.log2(torch.tensor(self.patch_size)).item())
        self.range = range(num_levels)
        # Tap indices, used to alternate the sign of the high-pass filter.
        self.register_buffer(
            "_arange",
            torch.arange(taps.shape[0]),
            persistent=_PERSISTENT,
        )
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        if self.patch_method == "haar":
            return self._ihaar(x)
        if self.patch_method == "rearrange":
            return self._iarrange(x)
        raise ValueError("Unknown patch method: " + self.patch_method)

    def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
        """One level of the inverse 2D wavelet step.

        ``wavelet`` and ``mode`` are accepted but not used: the filter always
        comes from ``self.wavelets``.
        """
        dtype = x.dtype
        taps = self.wavelets
        n = taps.shape[0]

        # The channel axis holds four sub-bands per original channel.
        g = x.shape[1] // 4
        lo = taps.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
        hi = (taps * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
        hi = hi.to(dtype=dtype)
        lo = lo.to(dtype=dtype)

        xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)

        # Undo the height split (stride 2 on dim 2) ...
        lo_h = lo.unsqueeze(3)
        hi_h = hi.unsqueeze(3)
        yl = F.conv_transpose2d(xll, lo_h, groups=g, stride=(2, 1), padding=(n - 2, 0))
        yl += F.conv_transpose2d(xlh, hi_h, groups=g, stride=(2, 1), padding=(n - 2, 0))
        yh = F.conv_transpose2d(xhl, lo_h, groups=g, stride=(2, 1), padding=(n - 2, 0))
        yh += F.conv_transpose2d(xhh, hi_h, groups=g, stride=(2, 1), padding=(n - 2, 0))

        # ... then the width split (stride 2 on dim 3).
        y = F.conv_transpose2d(yl, lo.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2))
        y += F.conv_transpose2d(yh, hi.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2))

        if rescale:
            y = y * 2
        return y

    def _ihaar(self, x):
        """Apply ``_idwt`` once per level in ``self.range``."""
        for _ in self.range:
            x = self._idwt(x, "haar", rescale=True)
        return x

    def _iarrange(self, x):
        """Move patch entries from the channel axis back to spatial tiles."""
        return rearrange(
            x,
            "b (c p1 p2) h w -> b c (h p1) (w p2)",
            p1=self.patch_size,
            p2=self.patch_size,
        )
248
+
249
+
250
class UnPatcher3D(UnPatcher):
    """3D inverse patcher for video wavelet decompositions (5-D tensors).

    After the inverse transform, the first ``patch_size - 1`` frames are
    dropped from the time axis.
    """

    def __init__(self, patch_size=1, patch_method="haar"):
        super().__init__(patch_method=patch_method, patch_size=patch_size)

    def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False):
        """One level of the inverse 3D wavelet step.

        ``wavelet`` and ``mode`` are accepted but not used: the filter always
        comes from ``self.wavelets``.
        """
        dtype = x.dtype
        taps = self.wavelets
        n = taps.shape[0]

        # The channel axis holds eight spatio-temporal sub-bands per channel.
        g = x.shape[1] // 8
        lo = taps.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
        hi = (taps * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
        lo = lo.to(dtype=dtype)
        hi = hi.to(dtype=dtype)

        xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)

        # Undo the split along the last axis (dim 4, stride 2).
        lo_w = lo.unsqueeze(2).unsqueeze(3)
        hi_w = hi.unsqueeze(2).unsqueeze(3)
        xll = F.conv_transpose3d(xlll, lo_w, groups=g, stride=(1, 1, 2))
        xll += F.conv_transpose3d(xllh, hi_w, groups=g, stride=(1, 1, 2))

        xlh = F.conv_transpose3d(xlhl, lo_w, groups=g, stride=(1, 1, 2))
        xlh += F.conv_transpose3d(xlhh, hi_w, groups=g, stride=(1, 1, 2))

        xhl = F.conv_transpose3d(xhll, lo_w, groups=g, stride=(1, 1, 2))
        xhl += F.conv_transpose3d(xhlh, hi_w, groups=g, stride=(1, 1, 2))

        xhh = F.conv_transpose3d(xhhl, lo_w, groups=g, stride=(1, 1, 2))
        xhh += F.conv_transpose3d(xhhh, hi_w, groups=g, stride=(1, 1, 2))

        # Undo the split along dim 3 (stride 2).
        lo_h = lo.unsqueeze(2).unsqueeze(4)
        hi_h = hi.unsqueeze(2).unsqueeze(4)
        xl = F.conv_transpose3d(xll, lo_h, groups=g, stride=(1, 2, 1))
        xl += F.conv_transpose3d(xlh, hi_h, groups=g, stride=(1, 2, 1))
        xh = F.conv_transpose3d(xhl, lo_h, groups=g, stride=(1, 2, 1))
        xh += F.conv_transpose3d(xhh, hi_h, groups=g, stride=(1, 2, 1))

        # Undo the split along the time axis (dim 2, stride 2).
        lo_t = lo.unsqueeze(3).unsqueeze(4)
        hi_t = hi.unsqueeze(3).unsqueeze(4)
        x = F.conv_transpose3d(xl, lo_t, groups=g, stride=(2, 1, 1))
        x += F.conv_transpose3d(xh, hi_t, groups=g, stride=(2, 1, 1))

        if rescale:
            x = x * (2 * torch.sqrt(torch.tensor(2.0)))
        return x

    def _ihaar(self, x):
        """Apply ``_idwt`` per level, then drop the leading repeated frames."""
        for _ in self.range:
            x = self._idwt(x, "haar", rescale=True)
        return x[:, :, self.patch_size - 1 :, ...]

    def _iarrange(self, x):
        """Un-rearrange the patch cubes, then drop the leading repeated frames."""
        restored = rearrange(
            x,
            "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)",
            p1=self.patch_size,
            p2=self.patch_size,
            p3=self.patch_size,
        )
        return restored[:, :, self.patch_size - 1 :, ...]
cosmos_transfer1/auxiliary/tokenizer/modules/quantizers.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Quantizers for discrete image and video tokenization."""
17
+
18
+ from typing import Optional
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from einops import reduce
25
+ from loguru import logger as logging
26
+
27
+ from cosmos_transfer1.auxiliary.tokenizer.modules.utils import (
28
+ default,
29
+ entropy,
30
+ pack_one,
31
+ rearrange,
32
+ round_ste,
33
+ unpack_one,
34
+ )
35
+
36
+
37
class ResidualFSQuantizer(nn.Module):
    """Residual Finite Scalar Quantization

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

    Holds ``num_quantizers`` ``FSQuantizer`` layers that all share the same
    ``levels``. Each layer quantizes what the previous layers left over, and
    the per-layer quantized outputs are summed.
    """

    def __init__(self, levels: list[int], num_quantizers: int, **ignore_kwargs):
        """Build the quantizer stack.

        Args:
            levels: Per-dimension level counts, forwarded to every layer.
            num_quantizers: Number of residual layers.
            **ignore_kwargs: Only ``dtype`` is read (output dtype of
                ``forward``, default ``torch.float32``); the rest is ignored.
        """
        super().__init__()
        self.dtype = ignore_kwargs.get("dtype", torch.float32)
        # NOTE(review): layers are built without a ``dtype`` kwarg, so each
        # layer falls back to FSQuantizer's own default dtype -- confirm this
        # is intended.
        self.layers = nn.ModuleList([FSQuantizer(levels=levels) for _ in range(num_quantizers)])

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize ``x`` layer by layer.

        Returns:
            ``(indices, quantized_out, loss_out)``: the per-layer indices
            stacked on dim 1, the summed quantized outputs and the summed
            per-layer losses, the latter two cast to ``self.dtype``.
        """
        indices_stack = []
        residual = x
        quantized_out = 0
        loss_out = 0
        for layer in self.layers:
            quant_indices, z, loss = layer(residual)
            indices_stack.append(quant_indices)
            # The next layer sees what is still unexplained; detach so the
            # residual path carries no gradient from the quantized value.
            residual = residual - z.detach()
            quantized_out = quantized_out + z
            loss_out = loss_out + loss
        # Keep the final residual around on the module (as the original did).
        self.residual = residual
        indices = torch.stack(indices_stack, dim=1)
        return indices, quantized_out.to(self.dtype), loss_out.to(self.dtype)

    def indices_to_codes(self, indices_stack: torch.Tensor) -> torch.Tensor:
        """Sum the codes of every layer for indices stacked on dim 1."""
        quantized_out = 0
        for layer, indices in zip(self.layers, indices_stack.transpose(0, 1)):
            quantized_out += layer.indices_to_codes(indices)
        return quantized_out
68
+
69
+
70
class FSQuantizer(nn.Module):
    """Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505

    Code adapted from Jax version in Appendix A.1.

    Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/
    vector_quantize_pytorch/finite_scalar_quantization.py
    [Copyright (c) 2020 Phil Wang]
    https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE
    """

    def __init__(
        self,
        levels: list[int],
        dim: Optional[int] = None,
        num_codebooks=1,
        keep_num_codebooks_dim: Optional[bool] = None,
        scale: Optional[float] = None,
        **ignore_kwargs,
    ):
        """Set up the level/basis buffers, optional projections and the implicit codebook.

        Args:
            levels: Number of quantization levels per code dimension.
            dim: Input feature dimension; defaults to ``len(levels) * num_codebooks``.
                Linear in/out projections are created only when it differs
                from that value.
            num_codebooks: Number of codebooks the feature axis is split into.
            keep_num_codebooks_dim: Whether indices keep a codebook axis;
                defaults to ``num_codebooks > 1``.
            scale: Stored on the module; not used in this class.
            **ignore_kwargs: Only ``dtype`` is read (default ``torch.bfloat16``).
        """
        super().__init__()
        self.dtype = ignore_kwargs.get("dtype", torch.bfloat16)
        # Accept any sequence (e.g. a tuple): the list concatenation below
        # would otherwise raise for non-list inputs.
        levels = list(levels)
        _levels = torch.tensor(levels, dtype=torch.int32)
        self.register_buffer("_levels", _levels, persistent=False)

        # Mixed-radix place values: basis[i] = prod(levels[:i]).
        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32)
        self.register_buffer("_basis", _basis, persistent=False)

        self.scale = scale

        codebook_dim = len(levels)
        self.codebook_dim = codebook_dim

        effective_codebook_dim = codebook_dim * num_codebooks
        self.num_codebooks = num_codebooks
        self.effective_codebook_dim = effective_codebook_dim

        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        self.dim = default(dim, len(_levels) * num_codebooks)

        has_projections = self.dim != effective_codebook_dim
        self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
        self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
        self.has_projections = has_projections

        self.codebook_size = self._levels.prod().item()

        # Enumerate every code once so it can be looked up by index.
        implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
        self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)

    def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
        """Bound `z`, an array of shape (..., d)."""
        half_l = (self._levels - 1) * (1 + eps) / 2
        # Dimensions with an even level count get a half-step offset.
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).atanh()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: torch.Tensor) -> torch.Tensor:
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
        """Map normalized codes to non-negative per-dimension level values."""
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
        """Inverse of `_scale_and_shift`."""
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
        """Converts a `code` to an index in the codebook."""
        assert zhat.shape[-1] == self.codebook_dim
        zhat = self._scale_and_shift(zhat).float()
        return (zhat * self._basis).sum(dim=-1).to(torch.int32)

    def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor:
        """Inverse of `codes_to_indices`.

        When ``project_out`` is true the codes also go through
        ``self.project_out``. Inputs of image/video rank are returned
        channel-first. The result is cast to ``self.dtype``.
        """
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
        indices = rearrange(indices, "... -> ... 1")
        # Mixed-radix decode: one digit per code dimension.
        codes_non_centered = (indices // self._basis) % self._levels
        codes = self._scale_and_shift_inverse(codes_non_centered)

        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, "... c d -> ... (c d)")

        if project_out:
            codes = self.project_out(codes)

        if is_img_or_video:
            codes = rearrange(codes, "b ... d -> b d ...")

        return codes.to(self.dtype)

    def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Quantize ``z`` and return ``(indices, quantized, dummy_loss)``.

        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension (one entry per level in `levels`, per codebook)
        c - number of codebook dim

        ``dummy_loss`` is an all-zero tensor; ``quantized`` is cast to
        ``self.dtype``.
        """
        is_img_or_video = z.ndim >= 4

        # standardize image or video into (batch, seq, dimension)

        if is_img_or_video:
            z = rearrange(z, "b d ... -> b ... d")
            z, ps = pack_one(z, "b * d")

        assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"

        z = self.project_in(z)

        z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)

        codes = self.quantize(z)
        indices = self.codes_to_indices(codes)

        codes = rearrange(codes, "b n c d -> b n (c d)")

        out = self.project_out(codes)

        # reconstitute image or video dimensions

        if is_img_or_video:
            out = unpack_one(out, ps, "b * d")
            out = rearrange(out, "b ... d -> b d ...")
            indices = unpack_one(indices, ps, "b * c")
            dummy_loss = torch.zeros_like(out.mean(dim=[1, 2, 3], keepdim=True))
        else:
            dummy_loss = torch.zeros_like(out.mean(dim=[1, 2], keepdim=True)).unsqueeze(1)

        if not self.keep_num_codebooks_dim:
            indices = rearrange(indices, "... 1 -> ...")

        return (indices, out.to(self.dtype), dummy_loss)
211
+
212
+
213
class VectorQuantizer(nn.Module):
    """Improved version over VectorQuantizer. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.

    Adapted from: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/
    taming/modules/vqvae/quantize.py

    [Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer]
    https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/License.txt
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        beta: float = 0.25,
        remap: Optional[str] = None,
        unknown_index: str = "random",
        sane_index_shape: bool = False,
        legacy: bool = True,
        use_norm=False,
        **ignore_kwargs,
    ):
        """Create the codebook and the optional index remapping.

        Args:
            num_embeddings: Number of codebook entries.
            embedding_dim: Size of each codebook vector.
            beta: Weight used when combining the two loss terms.
            remap: Path passed to ``np.load`` giving the codebook indices in
                use; ``None`` disables remapping.
            unknown_index: ``"random"``, ``"extra"`` (adds one extra index),
                or a value assigned to indices missing from ``used``.
            sane_index_shape: Return indices shaped ``(b, h, w)``.
            legacy: Selects which loss term ``beta`` multiplies in ``forward``.
            use_norm: L2-normalize ``z`` and ``z_q`` on the last axis before
                computing the losses.
            **ignore_kwargs: Only ``dtype`` is read (default ``torch.float32``).
        """
        super().__init__()
        self.n_e = num_embeddings
        self.e_dim = embedding_dim
        self.beta = beta
        self.legacy = legacy
        self.norm = lambda x: F.normalize(x, dim=-1) if use_norm else x

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            self.re_embed = num_embeddings

        self.sane_index_shape = sane_index_shape
        self.dtype = ignore_kwargs.get("dtype", torch.float32)

    def remap_to_used(self, inds):
        """Map codebook indices to their positions in ``self.used``.

        ``inds`` needs a leading batch axis. Indices absent from ``used``
        become random indices or ``self.unknown_index``.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        """Inverse of `remap_to_used`; the extra token (if any) maps to ``used[0]``."""
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            # Out-of-place so the caller's tensor is not modified through the
            # reshape view; the extra token is simply set to zero.
            inds = inds.masked_fill(inds >= self.used.shape[0], 0)
        back = used[inds]
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        """Quantize a ``(b, c, h, w)`` tensor to its nearest codebook vectors.

        ``temp``, ``rescale_logits`` and ``return_logits`` exist only for
        interface compatibility and must keep their defaults.

        Returns:
            ``(z_q, loss, info)`` where ``z_q`` is ``(b, c, h, w)`` with
            straight-through gradients, ``loss`` has shape ``(b, 1, 1, 1)``
            and ``info`` is ``(indices, None, commit_loss, beta * emb_loss,
            perplexity)``. ``indices`` is flat by default, remapped when
            ``remap`` is set, and ``(b, h, w)`` when ``sane_index_shape`` is set.
        """
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits is False, "Only for interface compatible with Gumbel"
        assert return_logits is False, "Only for interface compatible with Gumbel"
        # (b, c, h, w) -> (b, h, w, c)
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)

        codebook = self.embedding.weight
        # Squared distances via (z - e)^2 = z^2 + e^2 - 2 * z.e
        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(codebook**2, dim=1)
            - 2 * torch.einsum("bd,dn->bn", z_flattened, codebook.t())
        )

        encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self.n_e, device=z.device)
        encodings.scatter_(1, encoding_indices, 1)
        z_q = torch.matmul(encodings, codebook).view(z.shape)
        min_encodings = None

        z_q, z = self.norm(z_q), self.norm(z)

        # compute loss for embedding
        commit_loss = torch.mean((z_q - z.detach()) ** 2, dim=[1, 2, 3], keepdim=True)
        emb_loss = torch.mean((z_q.detach() - z) ** 2, dim=[1, 2, 3], keepdim=True)
        if not self.legacy:
            loss = self.beta * emb_loss + commit_loss
        else:
            loss = emb_loss + self.beta * commit_loss

        # preserve gradients
        z_q = z + (z_q - z).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        # Always defined, so `sane_index_shape` also works without `remap`.
        min_encoding_indices = encoding_indices.squeeze(1)

        if self.remap is not None:
            # `remap_to_used` requires a batch axis.
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])

        # TODO: return (indices, z_q, loss)
        return (
            z_q,
            loss,
            (
                min_encoding_indices,
                min_encodings,
                commit_loss.mean().detach(),
                self.beta * emb_loss.mean().detach(),
                perplexity.mean().detach(),
            ),
        )

    def get_codebook_entry(self, indices, shape):
        """Look up codebook vectors for ``indices``.

        ``shape`` specifies (batch, height, width, channel); when given, the
        result is reshaped to it and returned channel-first.
        """
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
367
+
368
+
369
class LFQuantizer(nn.Module):
    """Lookup-Free Quantization

    Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/
    vector_quantize_pytorch/lookup_free_quantization.py
    [Copyright (c) 2020 Phil Wang]
    https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE
    """

    def __init__(
        self,
        *,
        codebook_size: int,
        codebook_dim: int,
        embed_dim: Optional[int] = None,  # if None, use codebook_dim
        entropy_loss_weight=0.1,
        commitment_loss_weight=0.25,
        default_temp: float = 0.01,
        entropy_loss: bool = False,
        **ignore_kwargs,
    ):
        """Lookup-Free Quantization

        Args:
            codebook_size (int): The number of entries in the codebook.
            codebook_dim (int): The number of bits in each code.
            embed_dim (Optional[int], optional): The dimension of the input embedding. Defaults to None.
            entropy_loss_weight (float, optional): Weight for the entropy loss. Defaults to 0.1.
            commitment_loss_weight (float, optional): Weight for commitment loss. Defaults to 0.25.
            default_temp (float, optional): The temperature to use. Defaults to 0.01.
            entropy_loss (bool, optional): Flag for entropy loss. Defaults to False.
        """
        super().__init__()
        self.entropy_loss = entropy_loss
        self.codebook_dim = codebook_dim
        self.default_temp = default_temp
        # The attribute name is misspelled, but it is kept unchanged because
        # it is visible outside this class.
        self.entrop_loss_weight = entropy_loss_weight
        self.commitment_loss_weight = commitment_loss_weight
        # `forward` splits the feature axis by this value; it was previously
        # never set, so `forward` failed with AttributeError. The index
        # computation in `forward` collapses a codebook axis of size 1, so a
        # single codebook is the only supported configuration.
        self.num_codebooks = 1
        embed_dim = embed_dim or codebook_dim

        has_projections = embed_dim != codebook_dim
        self.project_in = nn.Linear(embed_dim, codebook_dim) if has_projections else nn.Identity()
        self.project_out = nn.Linear(codebook_dim, embed_dim) if has_projections else nn.Identity()
        logging.info(f"LFQ: has_projections={has_projections}, dim_in={embed_dim}, codebook_dim={codebook_dim}")

        self.dtype = ignore_kwargs.get("dtype", torch.float32)

        if entropy_loss:
            assert 2**codebook_dim == codebook_size, "codebook size must be 2 ** codebook_dim"
            self.codebook_size = codebook_size

            # Bit weights, most significant bit first.
            self.register_buffer(
                "mask",
                2 ** torch.arange(codebook_dim - 1, -1, -1),
                persistent=False,
            )
            self.register_buffer("zero", torch.tensor(0.0), persistent=False)

            # Every code as a +/-1 vector, one row per codebook entry.
            all_codes = torch.arange(codebook_size)
            bits = ((all_codes[..., None].int() & self.mask) != 0).float()
            codebook = 2 * bits - 1.0

            self.register_buffer("codebook", codebook, persistent=False)  # [codebook_size, codebook_dim]

    def forward(self, z: torch.Tensor, temp: Optional[float] = None) -> tuple:
        """Quantize a channel-first tensor to +/-1 codes.

        Args:
            z: Input with the feature axis at dim 1.
            temp: Softmax temperature for the entropy term; a falsy value
                selects ``self.default_temp``.

        Returns:
            With ``entropy_loss`` enabled:
            ``(z_q, loss, (indices, commit, entropy_aux, per_sample_entropy,
            codebook_entropy))``. Otherwise ``(z_q, loss, commit)``.
            ``loss`` has three trailing singleton axes appended; the logged
            scalars are detached and already weighted.
        """
        temp = temp or self.default_temp

        z = rearrange(z, "b d ... -> b ... d")
        z, ps = pack_one(z, "b * d")
        z = self.project_in(z)

        # split out number of codebooks
        z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)

        # quantization
        original_input = z

        codebook_value = torch.ones_like(z)
        z_q = torch.where(z > 0, codebook_value, -codebook_value)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # commit loss
        commit_loss = ((original_input - z_q.detach()) ** 2).mean(dim=[1, 2, 3])

        z_q = rearrange(z_q, "b n c d -> b n (c d)")
        z_q = self.project_out(z_q)

        # reshape
        z_q = unpack_one(z_q, ps, "b * d")
        z_q = rearrange(z_q, "b ... d -> b d ...")

        loss = self.commitment_loss_weight * commit_loss

        # entropy loss (eq-5)
        if self.entropy_loss:
            # indices
            indices = reduce((z > 0).int() * self.mask.int(), "b n c d -> b n c", "sum")
            indices = unpack_one(indices, ps, "b * c")
            indices = rearrange(indices, "... 1 -> ...")

            distance = -2 * torch.einsum(
                "... i d, j d -> ... i j",
                original_input,
                self.codebook.to(original_input.dtype),
            )
            prob = (-distance / temp).softmax(dim=-1)
            per_sample_entropy = entropy(prob).mean(dim=[1, 2])
            avg_prob = reduce(prob, "... c d -> c d", "mean")
            codebook_entropy = entropy(avg_prob).mean()
            entropy_aux_loss = per_sample_entropy - codebook_entropy

            loss += self.entrop_loss_weight * entropy_aux_loss

            # TODO: return (indices, z_q, loss)
            return (
                z_q,
                loss.unsqueeze(1).unsqueeze(1).unsqueeze(1),
                (
                    indices,
                    self.commitment_loss_weight * commit_loss.mean().detach(),
                    self.entrop_loss_weight * entropy_aux_loss.mean().detach(),
                    self.entrop_loss_weight * per_sample_entropy.mean().detach(),
                    self.entrop_loss_weight * codebook_entropy.mean().detach(),
                ),
            )
        else:
            return (
                z_q,
                loss.unsqueeze(1).unsqueeze(1).unsqueeze(1),
                self.commitment_loss_weight * commit_loss.mean().detach(),
            )
502
+
503
+
504
class InvQuantizerJit(nn.Module):
    """Use for decoder_jit to trace quantizer in discrete tokenizer"""

    def __init__(self, quantizer):
        """Wrap ``quantizer`` so its index-to-code lookup runs as ``forward``."""
        super().__init__()
        self.quantizer = quantizer

    def forward(self, indices: torch.Tensor):
        """Return the wrapped quantizer's codes for ``indices``, cast to its dtype."""
        wrapped = self.quantizer
        target_dtype = wrapped.dtype
        return wrapped.indices_to_codes(indices).to(target_dtype)
cosmos_transfer1/auxiliary/tokenizer/modules/utils.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Shared utilities for the networks module."""
17
+
18
+ from typing import Any
19
+
20
+ import torch
21
+ from einops import pack, rearrange, unpack
22
+
23
+
24
def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Fold the time axis of a ``(b, c, t, h, w)`` tensor into the batch axis.

    Returns the ``((b t), c, h, w)`` tensor and the original batch size,
    which ``batch2time`` needs to undo the fold.
    """
    folded = rearrange(x, "b c t h w -> (b t) c h w")
    return folded, x.shape[0]
27
+
28
+
29
def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Split a ``((b t), c, h, w)`` tensor back into ``(b, c, t, h, w)``."""
    unfolded = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
    return unfolded
31
+
32
+
33
def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
    """Fold the spatial axes of a ``(b, c, t, h, w)`` tensor into the batch axis.

    Returns:
        The ``((b h w), c, t)`` tensor, the original batch size and the
        original height (``x.shape[-2]``); ``batch2space`` takes the latter
        two to undo the fold.
    """
    batch_size, height = x.shape[0], x.shape[-2]
    return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height
36
+
37
+
38
def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
    """Split a ``((b h w), c, t)`` tensor back into ``(b, c, t, h, w)``."""
    unfolded = rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)
    return unfolded
40
+
41
+
42
def cast_tuple(t: Any, length: int = 1) -> Any:
    """Return ``t`` unchanged if it is a tuple, else ``t`` repeated ``length`` times as a tuple."""
    if isinstance(t, tuple):
        return t
    return (t,) * length
44
+
45
+
46
def replication_pad(x):
    """Prepend a copy of the first slice along dimension 2 to ``x``."""
    first_slice = x[:, :, :1, ...]
    return torch.cat([first_slice, x], dim=2)
48
+
49
+
50
def divisible_by(num: int, den: int) -> bool:
    """Return whether ``num`` is an exact multiple of ``den``."""
    remainder = num % den
    return remainder == 0
52
+
53
+
54
def is_odd(n: int) -> bool:
    """Return whether ``n`` is not divisible by two."""
    return (n % 2) != 0
56
+
57
+
58
def nonlinearity(x):
    """Return ``x`` multiplied elementwise by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
60
+
61
+
62
def Normalize(in_channels, num_groups=32):
    """Build an affine ``GroupNorm`` over ``in_channels`` channels with ``eps=1e-6``."""
    group_norm = torch.nn.GroupNorm(
        num_groups=num_groups,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
    return group_norm
64
+
65
+
66
class CausalNormalize(torch.nn.Module):
    """GroupNorm wrapper that normalizes each time step separately when ``num_groups == 1``."""

    def __init__(self, in_channels, num_groups=1):
        """Create the affine ``GroupNorm`` (``eps=1e-6``) and remember ``num_groups``."""
        super().__init__()
        self.num_groups = num_groups
        self.norm = torch.nn.GroupNorm(num_groups, in_channels, eps=1e-6, affine=True)

    def forward(self, x):
        """Normalize ``x``.

        With ``num_groups == 1`` the time axis is folded into the batch axis
        before the norm and unfolded afterwards. Any other value applies the
        group norm to ``x`` directly (spatio-temporal), which is kept for
        backward compatibility. All new models should use ``num_groups=1``,
        otherwise causality is not guaranteed.
        """
        if self.num_groups != 1:
            return self.norm(x)
        frames, batch_size = time2batch(x)
        return batch2time(self.norm(frames), batch_size)
84
+
85
+
86
def exists(v):
    """Return ``True`` for anything other than ``None``."""
    return not (v is None)
88
+
89
+
90
def default(*args):
    """Return the first argument that is not ``None``, or ``None`` if there is none."""
    return next((candidate for candidate in args if candidate is not None), None)
95
+
96
+
97
def pack_one(t, pattern):
    """Pack the single tensor ``t`` with einops ``pack``; returns ``(packed, packed_shapes)``."""
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
99
+
100
+
101
def unpack_one(t, ps, pattern):
    """Unpack ``t`` with einops ``unpack`` and return the first resulting tensor."""
    pieces = unpack(t, ps, pattern)
    return pieces[0]
103
+
104
+
105
def round_ste(z: torch.Tensor) -> torch.Tensor:
    """Round with straight through gradients."""
    # The rounding correction is detached, so the gradient w.r.t. ``z`` is
    # that of the identity.
    correction = (z.round() - z).detach()
    return z + correction
109
+
110
+
111
def log(t, eps=1e-5):
    """Natural log of ``t`` after clamping it from below at ``eps``."""
    clamped = t.clamp(min=eps)
    return clamped.log()
113
+
114
+
115
def entropy(prob):
    """Entropy of ``prob`` over its last axis, using a log clamped at ``1e-5``."""
    log_prob = prob.clamp(min=1e-5).log()
    return (-prob * log_prob).sum(dim=-1)
cosmos_transfer1/auxiliary/tokenizer/networks/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from enum import Enum
17
+
18
+ from cosmos_transfer1.auxiliary.tokenizer.networks.configs import continuous_image as continuous_image_dict
19
+ from cosmos_transfer1.auxiliary.tokenizer.networks.configs import continuous_video as continuous_video_dict
20
+ from cosmos_transfer1.auxiliary.tokenizer.networks.configs import discrete_image as discrete_image_dict
21
+ from cosmos_transfer1.auxiliary.tokenizer.networks.configs import discrete_video as discrete_video_dict
22
+ from cosmos_transfer1.auxiliary.tokenizer.networks.continuous_image import ContinuousImageTokenizer
23
+ from cosmos_transfer1.auxiliary.tokenizer.networks.continuous_video import CausalContinuousVideoTokenizer
24
+ from cosmos_transfer1.auxiliary.tokenizer.networks.discrete_image import DiscreteImageTokenizer
25
+ from cosmos_transfer1.auxiliary.tokenizer.networks.discrete_video import CausalDiscreteVideoTokenizer
26
+
27
+
28
class TokenizerConfigs(Enum):
    """Default tokenizer config dicts from the ``configs`` module, keyed by a two-letter code."""

    CI = continuous_image_dict  # continuous image
    DI = discrete_image_dict  # discrete image
    CV = continuous_video_dict  # continuous video
    DV = discrete_video_dict  # discrete video
33
+
34
+
35
class TokenizerModels(Enum):
    """Tokenizer model classes, keyed by the same two-letter codes as ``TokenizerConfigs``."""

    CI = ContinuousImageTokenizer  # continuous image
    DI = DiscreteImageTokenizer  # discrete image
    CV = CausalContinuousVideoTokenizer  # continuous video
    DV = CausalDiscreteVideoTokenizer  # discrete video
cosmos_transfer1/auxiliary/tokenizer/networks/configs.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """The default image and video tokenizer configs."""
17
+
18
+ from cosmos_transfer1.auxiliary.tokenizer.modules import (
19
+ ContinuousFormulation,
20
+ Decoder3DType,
21
+ DecoderType,
22
+ DiscreteQuantizer,
23
+ Encoder3DType,
24
+ EncoderType,
25
+ )
26
+
27
continuous_image = {
    # Attention resolution for the res blocks.
    "attn_resolutions": [32],
    # Base number of channels.
    "channels": 128,
    # Channel multiplier for each resolution.
    "channels_mult": [2, 4, 4],
    "dropout": 0.0,
    "in_channels": 3,
    # Spatial compression ratio.
    "spatial_compression": 16,
    # Number of layers in each res block.
    "num_res_blocks": 2,
    "out_channels": 3,
    "resolution": 1024,
    "patch_size": 4,
    "patch_method": "haar",
    # Output latent dimension (channels).
    "latent_channels": 16,
    # Encoder output channels just before sampling, which is also the
    # decoder's input channels.
    "z_channels": 16,
    # Factor over z_channels giving the total channels the encoder should
    # output. A VAE, for instance, outputs mean and variance and so needs
    # 2 * z_channels.
    "z_factor": 1,
    "name": "CI",
    # Formulation to use, either "AE" or "VAE".
    # NOTE(review): the earlier comment said VAE was chosen because the
    # pre-trained checkpoints use a VAE formulation, but the value selected
    # here is AE -- confirm which is intended.
    "formulation": ContinuousFormulation.AE.name,
    # Type of encoder ["Default", "LiteVAE"].
    "encoder": EncoderType.Default.name,
    # Type of decoder ["Default"].
    "decoder": DecoderType.Default.name,
}
61
+
62
discrete_image = {
    # Attention resolution for the res blocks.
    "attn_resolutions": [32],
    # Base number of channels.
    "channels": 128,
    # Channel multiplier for each resolution.
    "channels_mult": [2, 4, 4],
    "dropout": 0.0,
    "in_channels": 3,
    # Spatial compression ratio.
    "spatial_compression": 16,
    # Number of layers in each res block.
    "num_res_blocks": 2,
    "out_channels": 3,
    "resolution": 1024,
    "patch_size": 4,
    "patch_method": "haar",
    # Encoder output channels just before sampling.
    "z_channels": 256,
    # Factor over z_channels giving the total channels the encoder should
    # output. Discrete tokenization often uses the vector directly, hence 1.
    "z_factor": 1,
    # Quantizer of choice: VQ, LFQ, FSQ, or ResFSQ.
    "quantizer": DiscreteQuantizer.FSQ.name,
    # Embedding dimension post-quantization, which is also the input channels
    # of the decoder.
    "embedding_dim": 6,
    # Number of levels to use for finite scalar quantization.
    "levels": [8, 8, 8, 5, 5, 5],
    # Number of quantizers to use for residual finite scalar quantization.
    "num_quantizers": 4,
    "name": "DI",
    # Type of encoder ["Default", "LiteVAE"].
    "encoder": EncoderType.Default.name,
    # Type of decoder ["Default"].
    "decoder": DecoderType.Default.name,
}
99
+
100
# Default config for the continuous video tokenizer ("CV").
continuous_video = {
    "attn_resolutions": [32],
    "channels": 128,
    "channels_mult": [2, 4, 4],
    "dropout": 0.0,
    "in_channels": 3,
    "num_res_blocks": 2,
    "out_channels": 3,
    "resolution": 1024,
    "patch_size": 4,
    "patch_method": "haar",
    "latent_channels": 16,
    "z_channels": 16,
    "z_factor": 1,
    "num_groups": 1,
    "legacy_mode": False,
    "spatial_compression": 8,
    "temporal_compression": 8,
    "formulation": ContinuousFormulation.AE.name,
    "encoder": Encoder3DType.FACTORIZED.name,
    "decoder": Decoder3DType.FACTORIZED.name,
    "name": "CV",
}
123
+
124
# Default config for the discrete video tokenizer ("DV").
discrete_video = {
    "attn_resolutions": [32],
    "channels": 128,
    "channels_mult": [2, 4, 4],
    "dropout": 0.0,
    "in_channels": 3,
    "num_res_blocks": 2,
    "out_channels": 3,
    "resolution": 1024,
    "patch_size": 4,
    "patch_method": "haar",
    "z_channels": 16,
    "z_factor": 1,
    "num_groups": 1,
    "legacy_mode": False,
    "spatial_compression": 16,
    "temporal_compression": 8,
    "quantizer": DiscreteQuantizer.FSQ.name,
    "embedding_dim": 6,
    "levels": [8, 8, 8, 5, 5, 5],
    "encoder": Encoder3DType.FACTORIZED.name,
    "decoder": Decoder3DType.FACTORIZED.name,
    "name": "DV",
}