hanjiaming.0208 committed
Commit 146dae5 · 1 Parent(s): 5c9a353
.gitattributes CHANGED
@@ -25,7 +25,6 @@
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +32,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/animal.png filter=lfs diff=lfs merge=lfs -text
+examples/bell_ring.wav filter=lfs diff=lfs merge=lfs -text
+examples/caixukun.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/flower.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/food_menu.png filter=lfs diff=lfs merge=lfs -text
+examples/star_kun.mp4 filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2/build/lib.linux-x86_64-cpython-39/pointnet2_cuda.cpython-39-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2/build/temp.linux-x86_64-cpython-39/src/ball_query.o filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2/build/temp.linux-x86_64-cpython-39/src/group_points.o filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2/build/temp.linux-x86_64-cpython-39/src/interpolate.o filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2/build/temp.linux-x86_64-cpython-39/src/pointnet2_api.o filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2/build/temp.linux-x86_64-cpython-39/src/sampling.o filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2/dist/pointnet2-0.0.0-py3.9-linux-x86_64.egg filter=lfs diff=lfs merge=lfs -text
+model/lib/pointnet2_cuda.cpython-39-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,10 @@
+__pycache__
+*.pyc
+*.egg-info
+dist
+
+output
+output_dir
+*.pth
+*.log
+weights
README.md CHANGED
@@ -1,14 +1,46 @@
 ---
-title: Tar 7B
-emoji: 🦀
-colorFrom: purple
-colorTo: gray
+title: Tar
+emoji: 🚀
+colorFrom: red
+colorTo: indigo
 sdk: gradio
-sdk_version: 5.35.0
+sdk_version: 5.34.0
 app_file: app.py
 pinned: false
-license: apache-2.0
+python_version: 3.10.18
 short_description: Unified MLLM with Text-Aligned Representations
+license: apache-2.0
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+### Unifying Visual Understanding and Generation via Text-Aligned Representations
+> [Jiaming Han](https://csuhan.com), [Hao Chen](https://haochen-rye.github.io)<sup>†</sup>, [Yang Zhao](https://scholar.google.com/citations?user=uPmTOHAAAAAJ&hl=zh-CN), [Hanyu Wang](https://hywang66.github.io), [Qi Zhao](https://kevinz8866.github.io), [Ziyan Yang](https://ziyanyang.github.io), [Hao He](https://hehao13.github.io), [Xiangyu Yue](https://xyue.io)<sup>‡</sup>, [Lu Jiang](https://www.lujiang.info)<sup>‡</sup>
+>
+> <sup>†</sup> Project Lead&nbsp;&nbsp;<sup>‡</sup> Corresponding Authors
+
+<a href="https://tar.csuhan.com">
+  <img
+    src="https://img.shields.io/badge/Project-Page-0A66C2?logo=chromewebstore&logoColor=0A66C2"
+    alt="Project Page"
+  />
+</a>
+<a href="http://arxiv.org/abs/2506.18898">
+  <img
+    src="https://img.shields.io/badge/arXiv-Paper-red?logo=arxiv&logoColor=red"
+    alt="Tar Paper on arXiv"
+  />
+</a>
+
+
+### Citation
+```
+@article{han2025tar,
+  title={Vision as a Dialect: Unifying Visual Understanding and Generation via Text-Aligned Representations},
+  author={Han, Jiaming and Chen, Hao and Zhao, Yang and Wang, Hanyu and Zhao, Qi and Yang, Ziyan and He, Hao and Yue, Xiangyu and Jiang, Lu},
+  journal={arXiv preprint arXiv:2506.18898},
+  year={2025},
+}
+```
+
+### License
+This project is licensed under the Apache 2.0 License.
+
app.py ADDED
@@ -0,0 +1,133 @@
+# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# //
+# // Licensed under the Apache License, Version 2.0 (the "License");
+# // you may not use this file except in compliance with the License.
+# // You may obtain a copy of the License at
+# //
+# // http://www.apache.org/licenses/LICENSE-2.0
+# //
+# // Unless required by applicable law or agreed to in writing, software
+# // distributed under the License is distributed on an "AS IS" BASIS,
+# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# // See the License for the specific language governing permissions and
+# // limitations under the License.
+
+import os
+import gradio as gr
+from torchvision.transforms.functional import to_tensor
+from huggingface_hub import hf_hub_download, snapshot_download, login
+import spaces
+
+from tok.ar_dtok.ar_model import ARModel
+from t2i_inference import T2IConfig, TextToImageInference
+
+def generate_text(self, image: str, prompt: str) -> str:
+    image = image.convert('RGB')
+    image = to_tensor(image).unsqueeze(0).to(self.device)
+
+    image_code = self.visual_tokenizer.encoder(image.to(self.config.dtype))['bottleneck_rep']
+    image_text = "".join([f"<I{x}>" for x in image_code[0].cpu().tolist()])
+
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": f"{image_text}\n{prompt}"}
+    ]
+
+    input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = self.tokenizer(input_text, return_tensors="pt")
+
+    gen_ids = self.model.generate(
+        inputs.input_ids.to(self.device),
+        max_new_tokens=512,
+        do_sample=True)
+    return self.tokenizer.batch_decode(gen_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
+
+login(token=os.getenv('HF_TOKEN'))
+config = T2IConfig()
+config.model = snapshot_download("ByteDance-Seed/Tar-7B")
+config.ar_path = {
+    "1024px": hf_hub_download("ByteDance-Seed/Tar-TA-Tok", "ar_dtok_lp_1024px.pth"),
+    "512px": hf_hub_download("ByteDance-Seed/Tar-TA-Tok", "ar_dtok_lp_512px.pth"),
+}
+config.encoder_path = hf_hub_download("ByteDance-Seed/Tar-TA-Tok", "ta_tok.pth")
+config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
+inference = TextToImageInference(config)
+
+@spaces.GPU(duration=120)
+def generate_image(prompt, resolution, top_p, top_k, cfg_scale):
+    image = inference.generate_image(prompt, resolution, top_p, top_k, cfg_scale)
+    return image
+
+def clear_inputs_t2i():
+    return "", None
+
+@spaces.GPU(duration=120)
+def understand_image(image, prompt):
+    return generate_text(inference, image, prompt)
+
+def clear_inputs_i2t():
+    return None, "", ""
+
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+        """
+        <div align="center">
+
+        ### Tar: Unifying Visual Understanding and Generation via Text-Aligned Representations
+
+        [🕸️ Project Page](http://tar.csuhan.com) • [📄 Paper](http://arxiv.org/abs/2506.18898) • [💻 Code](https://github.com/csuhan/Tar) • [📦 Model](https://huggingface.co/collections/ByteDance-Seed/tar-6864cf0d9fe59a3b91cc4260)
+
+        </div>
+        """,
+        elem_id="title",
+    )
+    with gr.Tab("Image Generation"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt")
+                with gr.Accordion("Advanced Settings", open=False):
+                    resolution = gr.Radio(
+                        ["512px", "1024px"], value="1024px", label="Resolution"
+                    )
+                    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
+                    top_k = gr.Slider(1, 2000, value=1200, step=10, label="Top-k")
+                    cfg_scale = gr.Slider(1.0, 20.0, value=4.0, step=0.5, label="CFG Scale")
+                with gr.Row():
+                    generate_btn = gr.Button("Generate")
+                    clear_btn = gr.Button("Clear")
+            with gr.Column(scale=2):
+                output_image = gr.Image(label="Generated Image")
+
+        generate_btn.click(
+            generate_image,
+            inputs=[prompt, resolution, top_p, top_k, cfg_scale],
+            outputs=output_image
+        )
+        clear_btn.click(
+            clear_inputs_t2i,
+            outputs=[prompt, output_image]
+        )
+
+    with gr.Tab("Image Understanding"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                image_input = gr.Image(label="Upload Image", type="pil")
+                question_input = gr.Textbox(label="Instruction", value="Describe the image shortly.")
+                with gr.Row():
+                    qa_btn = gr.Button("Generate")
+                    clear_btn_i2t = gr.Button("Clear")
+            with gr.Column(scale=1):
+                answer_output = gr.Textbox(label="Response", lines=4)
+
+        qa_btn.click(
+            understand_image,
+            inputs=[image_input, question_input],
+            outputs=answer_output
+        )
+
+        clear_btn_i2t.click(
+            clear_inputs_i2t,
+            outputs=[image_input, question_input, answer_output]
+        )
+
+demo.launch(share=True)
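
The understanding path above splices the visual tokenizer's discrete indices into the chat prompt as `<I{index}>` pseudo-tokens, and the generation path later recovers indices from such tokens with a regex. Below is a minimal standalone sketch of that round-trip protocol; the two helper functions are illustrative names invented here, not part of the repo.

```python
# Standalone sketch of the <I{index}> token protocol (illustrative only;
# encode_indices_to_text / decode_text_to_indices are hypothetical helpers).
import re
from typing import List

def encode_indices_to_text(indices: List[int]) -> str:
    # Visual-token indices become text tokens the LLM tokenizer can emit and consume.
    return "".join(f"<I{x}>" for x in indices)

def decode_text_to_indices(text: str) -> List[int]:
    # Mirrors the regex used in t2i_inference.py to pull indices back out of generated text.
    return [int(x) for x in re.findall(r"<I(\d+)>", text)]

codes = [17, 4096, 3]
text = encode_indices_to_text(codes)          # "<I17><I4096><I3>"
assert decode_text_to_indices(text) == codes  # lossless round-trip
```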
requirements.txt ADDED
@@ -0,0 +1,17 @@
+--extra-index-url https://download.pytorch.org/whl/cu121
+accelerate==0.28.0
+datasets==2.16.1
+deepspeed==0.14.4
+einops==0.8.1
+gradio==5.34.0
+huggingface_hub==0.29.1
+numpy==1.26.1
+Pillow==11.2.1
+pyarrow==17.0.0
+PyYAML==6.0.2
+torch==2.1.2
+torchvision==0.16.2
+tqdm==4.66.5
+transformers==4.50.0
+wandb
+easydict
t2i_inference.py ADDED
@@ -0,0 +1,101 @@
+# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# //
+# // Licensed under the Apache License, Version 2.0 (the "License");
+# // you may not use this file except in compliance with the License.
+# // You may obtain a copy of the License at
+# //
+# // http://www.apache.org/licenses/LICENSE-2.0
+# //
+# // Unless required by applicable law or agreed to in writing, software
+# // distributed under the License is distributed on an "AS IS" BASIS,
+# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# // See the License for the specific language governing permissions and
+# // limitations under the License.
+
+import re
+from dataclasses import dataclass
+
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from transformers import AutoTokenizer, Qwen2ForCausalLM
+
+from tok.mm_autoencoder import MMAutoEncoder
+
+
+@dataclass
+class T2IConfig:
+    model_path: str = "ByteDance-Seed/Tar-1.5B"
+    # visual tokenizer config
+    ar_path = None
+    encoder_path: str = 'ta_tok.pth'
+    decoder_path: str = 'vq_ds16_t2i.pt'
+
+    device: str = "cuda:0"
+    dtype: torch.dtype = torch.bfloat16
+    # generation parameters
+    scale: int = 0  # choose from [0, 1, 2]
+    seq_len: int = 729  # choose from [729, 169, 81]
+    temperature: float = 1.0
+    top_p: float = 0.95
+    top_k: int = 1200
+    cfg_scale: float = 4.0
+
+class TextToImageInference:
+    def __init__(self, config: T2IConfig):
+        self.config = config
+        self.device = torch.device(config.device)
+        self._load_models()
+
+    def _load_models(self):
+        self.model = Qwen2ForCausalLM.from_pretrained(self.config.model_path, torch_dtype=self.config.dtype).to(self.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)
+
+        # Initialize visual tokenizer
+        config = dict(
+            ar_path_dict=self.config.ar_path,
+            encoder_path=self.config.encoder_path,
+            decoder_path=self.config.decoder_path,
+            encoder_args={'input_type': 'rec'},
+            decoder_args={},
+        )
+        self.visual_tokenizer = MMAutoEncoder(**config).eval().to(dtype=self.config.dtype, device=self.device)
+        for ar_model in self.visual_tokenizer.ar_model.values():
+            ar_model.cls_token_num = self.config.seq_len
+        self.visual_tokenizer.encoder.pool_scale = self.config.scale + 1
+
+    def generate_image(self, prompt, resolution, top_p, top_k, cfg_scale) -> Image.Image:
+        # Prepare prompt
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": prompt}
+        ]
+
+        input_text = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True)
+        input_text += f"<im_start><S{self.config.scale}>"
+
+        # Generate tokens
+        inputs = self.tokenizer(input_text, return_tensors="pt")
+        gen_ids = self.model.generate(
+            inputs.input_ids.to(self.device),
+            max_new_tokens=self.config.seq_len,
+            do_sample=True,
+            temperature=self.config.temperature,
+            top_p=top_p,
+            top_k=top_k)
+
+        # Process generated tokens
+        gen_text = self.tokenizer.batch_decode(gen_ids)[0]
+        gen_code = [int(x) for x in re.findall(r'<I(\d+)>', gen_text)]
+        gen_code = gen_code[:self.config.seq_len] + [0] * max(0, self.config.seq_len - len(gen_code))
+        gen_code = torch.tensor(gen_code).unsqueeze(0).to(self.device)
+
+        gen_tensor = self.visual_tokenizer.decode_from_encoder_indices(
+            gen_code,
+            {'cfg_scale': cfg_scale, 'resolution': resolution},
+        )
+        gen_image = Image.fromarray(gen_tensor[0].numpy())
+        return gen_image
+
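
For reference, a minimal usage sketch of `TextToImageInference` outside the Gradio demo, under the assumptions that this repo is on `PYTHONPATH`, the listed checkpoints can be downloaded, and a CUDA device is available; the prompt and output filename are illustrative.

```python
# Minimal sketch (assumptions: repo importable, checkpoints downloadable, CUDA available).
from huggingface_hub import hf_hub_download, snapshot_download

from t2i_inference import T2IConfig, TextToImageInference

config = T2IConfig()
config.model_path = snapshot_download("ByteDance-Seed/Tar-1.5B")
config.ar_path = {
    "1024px": hf_hub_download("ByteDance-Seed/Tar-TA-Tok", "ar_dtok_lp_1024px.pth"),
    "512px": hf_hub_download("ByteDance-Seed/Tar-TA-Tok", "ar_dtok_lp_512px.pth"),
}
config.encoder_path = hf_hub_download("ByteDance-Seed/Tar-TA-Tok", "ta_tok.pth")
config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")

inference = TextToImageInference(config)
image = inference.generate_image(
    "a photo of a corgi wearing sunglasses",   # illustrative prompt
    resolution="1024px", top_p=0.95, top_k=1200, cfg_scale=4.0,
)
image.save("sample.png")
```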
tok/__init__.py ADDED
@@ -0,0 +1,15 @@
+# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# //
+# // Licensed under the Apache License, Version 2.0 (the "License");
+# // you may not use this file except in compliance with the License.
+# // You may obtain a copy of the License at
+# //
+# // http://www.apache.org/licenses/LICENSE-2.0
+# //
+# // Unless required by applicable law or agreed to in writing, software
+# // distributed under the License is distributed on an "AS IS" BASIS,
+# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# // See the License for the specific language governing permissions and
+# // limitations under the License.
+
+from .ar_dtok import *
tok/ar_dtok/__init__.py ADDED
@@ -0,0 +1,16 @@
+# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# //
+# // Licensed under the Apache License, Version 2.0 (the "License");
+# // you may not use this file except in compliance with the License.
+# // You may obtain a copy of the License at
+# //
+# // http://www.apache.org/licenses/LICENSE-2.0
+# //
+# // Unless required by applicable law or agreed to in writing, software
+# // distributed under the License is distributed on an "AS IS" BASIS,
+# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# // See the License for the specific language governing permissions and
+# // limitations under the License.
+
+from .bottleneck import Bottleneck, SimVectorQuantizer
+from .vqvae import VQVAE
tok/ar_dtok/ar_model.py ADDED
@@ -0,0 +1,547 @@
+# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# //
+# // Licensed under the Apache License, Version 2.0 (the "License");
+# // you may not use this file except in compliance with the License.
+# // You may obtain a copy of the License at
+# //
+# // http://www.apache.org/licenses/LICENSE-2.0
+# //
+# // Unless required by applicable law or agreed to in writing, software
+# // distributed under the License is distributed on an "AS IS" BASIS,
+# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# // See the License for the specific language governing permissions and
+# // limitations under the License.
+
+import os
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from .. import models
+from .generate import generate as ar_generate
+
+
+def find_multiple(n: int, k: int):
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, scale_factor=10000):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    scale_factor: the base for the scaling factor, default is 10000
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega /= embed_dim / 2.
+    omega = 1. / scale_factor**omega  # Parameterized scaling factor (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+@dataclass
+class ModelArgs:
+    dim: int = 4096
+    n_layer: int = 32
+    n_head: int = 32
+
+    n_kv_head: Optional[int] = None
+    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+    ffn_dim_multiplier: Optional[float] = None
+    rope_base: float = 10000
+    norm_eps: float = 1e-5
+    initializer_range: float = 0.02
+
+    token_dropout_p: float = 0.1
+    attn_dropout_p: float = 0.0
+    resid_dropout_p: float = 0.1
+    ffn_dropout_p: float = 0.1
+    drop_path_rate: float = 0.0
+
+    num_classes: int = 1000
+    class_dropout_prob: float = 0.1
+    model_type: str = 'class_cond'  # clip_cond, indice_cond
+    cond_dim: int = 1152
+    cond_vocab_size: int = 8192
+
+    vocab_size: int = 8192
+    cls_token_num: int = 1
+
+    max_batch_size: int = 32
+    max_seq_len: int = 2048
+
+    use_fixed_pe: bool = False
+
+    frame_prediction: bool = False
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    @torch.autocast(device_type='cuda', enabled=False)
+    def _norm(self, x):
+        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+class MLP(nn.Module):
+    def __init__(self, in_features, hidden_features, out_features):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
+        self.act = nn.GELU(approximate='tanh')
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.fc2(x)
+        return x
+
+
+#################################################################################
+#                           Drop Path Implementation                            #
+#################################################################################
+
+def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+
+class DropPath(torch.nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
+
+#################################################################################
+#                                   AR Model                                    #
+#################################################################################
+
+class FeedForward(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        hidden_dim = 4 * config.dim
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if config.ffn_dim_multiplier is not None:
+            hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
+        hidden_dim = find_multiple(hidden_dim, config.multiple_of)
+
+        self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
+        self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
+        self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)
+
+    def forward(self, x):
+        return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
+
+
+class KVCache(nn.Module):
+    def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
+        super().__init__()
+        cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
+        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
+        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
+
+    def update(self, input_pos, k_val, v_val):
+        # input_pos: [S], k_val: [B, H, S, D]
+        assert input_pos.shape[0] == k_val.shape[2], f"{input_pos.shape[0]} != {k_val.shape[2]}"
+        k_out = self.k_cache
+        v_out = self.v_cache
+        k_out[:, :, input_pos] = k_val.to(k_out.dtype)
+        v_out[:, :, input_pos] = v_val.to(v_out.dtype)
+
+        return k_out, v_out
+
+
+class Attention(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        assert config.dim % config.n_head == 0
+        self.dim = config.dim
+        self.head_dim = config.dim // config.n_head
+        self.n_head = config.n_head
+        self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
+        total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim
+
+        # key, query, value projections for all heads, but in a batch
+        self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
+        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.kv_cache = None
+
+        # regularization
+        self.attn_dropout_p = config.attn_dropout_p
+        self.resid_dropout = nn.Dropout(config.resid_dropout_p)
+
+    def forward(
+        self, x: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None
+    ):
+        bsz, seqlen, _ = x.shape
+        kv_size = self.n_kv_head * self.head_dim
+        xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+        xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
+        xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)
+
+        xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
+
+        if self.kv_cache is not None:
+            keys, values = self.kv_cache.update(input_pos, xk, xv)
+        else:
+            keys, values = xk, xv
+        keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
+        values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
+
+        output = F.scaled_dot_product_attention(
+            xq, keys, values,
+            attn_mask=mask,
+            is_causal=True if mask is None else False,  # is_causal=False is for KV cache
+            dropout_p=self.attn_dropout_p if self.training else 0)
+
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+        output = self.resid_dropout(self.wo(output))
+        return output
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: ModelArgs, drop_path: float):
+        super().__init__()
+        self.attention = Attention(config)
+        self.feed_forward = FeedForward(config)
+        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(
+        self, x: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
+        h = x + self.drop_path(self.attention(self.attention_norm(x), start_pos, mask))
+        out = h + self.drop_path(self.feed_forward(self.ffn_norm(h)))
+        return out
+
+
+class LabelEmbedder(nn.Module):
+    """
+    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+    """
+    def __init__(self, num_classes, hidden_size, dropout_prob):
+        super().__init__()
+        use_cfg_embedding = dropout_prob > 0
+        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+        self.num_classes = num_classes
+        self.dropout_prob = dropout_prob
+
+    def token_drop(self, labels, force_drop_ids=None):
+        """
+        Drops labels to enable classifier-free guidance.
+        """
+        if force_drop_ids is None:
+            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+        else:
+            drop_ids = force_drop_ids == 1
+        labels = torch.where(drop_ids, self.num_classes, labels)
+        return labels
+
+    def forward(self, labels, train, force_drop_ids=None):
+        use_dropout = self.dropout_prob > 0
+        if (train and use_dropout) or (force_drop_ids is not None):
+            labels = self.token_drop(labels, force_drop_ids)
+
+        # replace all negative labels with the last class (unconditional class)
+        labels = torch.where(labels < 0, self.num_classes, labels)
+        embeddings = self.embedding_table(labels)
+        return embeddings
+
+
+class ARModel(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.n_layer = config.n_layer
+        self.max_seq_length = config.max_seq_len
+        self.num_classes = config.num_classes
+        self.model_type = config.model_type
+        self.cls_token_num = config.cls_token_num
+        self.is_sampling = False
+        self.frame_prediction = config.frame_prediction
+
+        if self.model_type == 'class_cond':
+            self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
+        elif self.model_type == 'clip_cond':
+            self.clip_proj = nn.Linear(config.cond_dim, config.dim)
+        elif self.model_type == 'indice_cond':
+            self.clip_proj = LabelEmbedder(config.cond_vocab_size + 1, config.dim, 0.0)
+        else:
+            raise Exception("please check model type")
+
+        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+        self.tok_dropout = nn.Dropout(config.token_dropout_p)
+
+        # transformer blocks
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
+        self.layers = torch.nn.ModuleList()
+        for layer_id in range(config.n_layer):
+            self.layers.append(TransformerBlock(config, dpr[layer_id]))
+
+        # output layer
+        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+
+        if config.use_fixed_pe:
+            self.register_buffer('abs_pe', torch.zeros(1, config.max_seq_len + config.cls_token_num - 1, config.dim))
+            abs_pe = get_1d_sincos_pos_embed_from_grid(embed_dim=config.dim, pos=np.arange(config.max_seq_len + config.cls_token_num - 1))
+            self.abs_pe.copy_(torch.from_numpy(abs_pe).float().reshape_as(self.abs_pe))
+            print(f"Using fixed absolute PE")
+        else:
+            self.abs_pe = nn.Parameter(torch.randn(1, config.max_seq_len + config.cls_token_num - 1, config.dim) * 0.02)
+            print(f"Using learned absolute PE")
+
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        # Initialize nn.Linear and nn.Embedding
+        self.apply(self._init_weights)
+
+        # Zero-out output layers:
+        if hasattr(self.output, 'weight') and isinstance(self.output.weight, nn.Parameter):
+            nn.init.constant_(self.output.weight, 0)
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+
+    @contextmanager
+    def sampling(self):
+        self.is_sampling = True
+        try:
+            yield
+        finally:
+            self.is_sampling = False
+
+
+    def setup_caches(self, max_batch_size, max_seq_length, dtype):
+        assert max_seq_length == self.max_seq_length + self.cls_token_num, f'{max_seq_length} != {self.max_seq_length} + {self.cls_token_num=}'
+
+        head_dim = self.config.dim // self.config.n_head
+        max_seq_length = find_multiple(max_seq_length, 8)
+
+        for b in self.layers:
+            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype)
+
+        causal_mask = torch.tril(torch.ones(max_seq_length, max_seq_length, dtype=torch.bool))
+        self.causal_mask = causal_mask.unsqueeze(0).repeat(max_batch_size, 1, 1)
+
+
+    def reset_caches(self):
+        for b in self.layers:
+            b.attention.kv_cache = None
+
+    def clip_embedding(self, x):
+        if self.model_type == 'clip_cond':
+            if self.training and self.config.class_dropout_prob > 0:
+                drop_ids = torch.rand(x.shape[0], device=x.device) < self.config.class_dropout_prob
+                x[drop_ids] = 0.
+            x = self.clip_proj(x.to(self.dtype))  # Linear
+        elif self.model_type == 'indice_cond':
+            if self.training and self.config.class_dropout_prob > 0:
+                drop_ids = torch.rand(x.shape[0], device=x.device) < self.config.class_dropout_prob
+                x[drop_ids] = self.config.cond_vocab_size
+            x = self.clip_proj(x, train=self.training)  # Embedding
+        return x
+
+    def forward(
+        self,
+        idx: Optional[torch.Tensor],  # (b, n)
+        cond_idx: Optional[torch.Tensor],  # cond_idx_or_embed
+        input_pos: Optional[torch.Tensor] = None,
+        targets: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        valid: Optional[torch.Tensor] = None,
+    ):
+        if idx is not None and cond_idx is not None:  # training or naive inference
+            if self.model_type == 'class_cond':
+                cond_embeddings = self.cls_embedding(cond_idx, train=self.training).unsqueeze(1)[:,:self.cls_token_num]
+            elif self.model_type in ['clip_cond', 'indice_cond']:
+                cond_embeddings = self.clip_embedding(cond_idx)
+            token_embeddings = self.tok_embeddings(idx)  # (b, n, d)
+            token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)  # (b, cls_token_num + n, d)
+            h = self.tok_dropout(token_embeddings)
+        else:
+            if cond_idx is not None:  # prefill in inference
+                if self.model_type == 'class_cond':
+                    token_embeddings = self.cls_embedding(cond_idx, train=self.training).unsqueeze(1)[:,:self.cls_token_num]
+                elif self.model_type in ['clip_cond', 'indice_cond']:
+                    token_embeddings = self.clip_embedding(cond_idx)
+            else:  # decode_n_tokens(kv cache) in inference
+                token_embeddings = self.tok_embeddings(idx)
+
+            bs = token_embeddings.shape[0]
+            mask = self.causal_mask[:bs, None, input_pos]
+            h = self.tok_dropout(token_embeddings)
+
+        if self.is_sampling:
+            h = h + self.abs_pe[:, input_pos]
+        else:
+            h = h + self.abs_pe[:, :h.shape[1]]
+
+        # transformer blocks
+        for layer in self.layers:
+            h = layer(h, input_pos, mask)
+
+        # output layers
+        h = self.norm(h)
+        logits = self.output(h)
+        # if self.training or self.is_sampling:
+        if cond_idx is not None:
+            # if self.training:
+            #     logits = logits[:, self.cls_token_num - 1:].contiguous()
+            logits = logits[:, cond_idx.size(1) - 1:].contiguous()
+
+        # if we are given some desired targets also calculate the loss
+        loss = None
+        if valid is not None:
+            loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
+            valid_all = valid[:,None].repeat(1, targets.shape[1]).view(-1)
+            loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
+        elif targets is not None:
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+        return logits, loss
+
+
+    @torch.inference_mode()
+    def sample(
+        self,
+        c,
+        cfg_scale=2.0,
+        cfg_interval=-1,
+        temperature=1.0,
+        top_k=0,
+        top_p=1.0,
+        seq_length=None,
+    ):
+        seq_length = self.max_seq_length if seq_length is None else seq_length
+        with self.sampling():
+            sampled_seqs = ar_generate(
+                self, c, seq_length,
+                cfg_scale=cfg_scale, cfg_interval=cfg_interval,
+                temperature=temperature, top_k=top_k,
+                top_p=top_p, sample_logits=True,
+            )
+        return sampled_seqs
+
+
+    @classmethod
+    def from_checkpoint(cls, ckpt, load_state_dict=True):
+        if isinstance(ckpt, str):
+            assert os.path.exists(ckpt), f"checkpoint {ckpt} does not exist"
+            ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)
+        else:
+            assert isinstance(
+                ckpt, dict
+            ), f"checkpoint must be a dict or a path to a checkpoint"
+        model = models.make(ckpt["model"], load_sd=load_state_dict)
+        return model
+
+
+#################################################################################
+#                               LLAMA-ABS Configs                               #
+#################################################################################
+
+def LLAMA_ABS_XXXL(**kwargs):
+    return ARModel(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs))  # 3.9B
+
+def LLAMA_ABS_XXL(**kwargs):
+    return ARModel(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs))  # 1.4B
+
+def LLAMA_ABS_XL(**kwargs):
+    return ARModel(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs))  # 775M
+
+def LLAMA_ABS_LP(**kwargs):
+    return ARModel(ModelArgs(n_layer=30, n_head=20, dim=1280, **kwargs))  # 632M
+
+def LLAMA_ABS_L(**kwargs):
+    return ARModel(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs))  # 343M
+
+def LLAMA_ABS_B(**kwargs):
+    return ARModel(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs))  # 111M
+
+def LLAMA_ABS_S(**kwargs):
+    return ARModel(ModelArgs(n_layer=12, n_head=6, dim=384, **kwargs))  # 21.7M
+
+ar_models = {
+    'llama-abs-S': LLAMA_ABS_S,
+    'llama-abs-B': LLAMA_ABS_B,
+    'llama-abs-L': LLAMA_ABS_L,
+    'llama-abs-LP': LLAMA_ABS_LP,
+    'llama-abs-XL': LLAMA_ABS_XL,
+    'llama-abs-XXL': LLAMA_ABS_XXL,
+    'llama-abs-XXXL': LLAMA_ABS_XXXL,
+}
+
+models.models.update(ar_models)
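
When `use_fixed_pe=True`, `ARModel` fills its `abs_pe` buffer with the standard 1-D sin/cos table from `get_1d_sincos_pos_embed_from_grid`. A standalone re-derivation of that table is sketched below in pure NumPy (it does not import this repo); the shape check assumes the default `max_seq_len=2048`, `cls_token_num=1`, and the `dim=384` of the LLAMA_ABS_S config.

```python
# Standalone re-derivation of the 1-D sin/cos positional table (pure NumPy sketch).
import numpy as np

def sincos_1d(embed_dim: int, pos: np.ndarray, scale_factor: float = 10000.0) -> np.ndarray:
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega = 1.0 / scale_factor ** (omega / (embed_dim / 2.0))    # (D/2,) frequencies
    out = np.einsum("m,d->md", pos.reshape(-1), omega)           # (M, D/2) phase matrix
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)    # (M, D)

pe = sincos_1d(embed_dim=384, pos=np.arange(2048 + 1 - 1))  # max_seq_len + cls_token_num - 1
print(pe.shape)  # (2048, 384) -- same shape as the abs_pe buffer registered in ARModel.__init__
```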
tok/ar_dtok/bottleneck.py ADDED
@@ -0,0 +1,198 @@
+# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# //
+# // Licensed under the Apache License, Version 2.0 (the "License");
+# // you may not use this file except in compliance with the License.
+# // You may obtain a copy of the License at
+# //
+# // http://www.apache.org/licenses/LICENSE-2.0
+# //
+# // Unless required by applicable law or agreed to in writing, software
+# // distributed under the License is distributed on an "AS IS" BASIS,
+# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# // See the License for the specific language governing permissions and
+# // limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from .. import models
+from ..models import register
+
+
+@register("bottleneck")
+class Bottleneck(nn.Module):
+    def __init__(
+        self,
+        bottleneck_dim: int,
+        input_dim: int,
+        output_dim: int,
+        token_nums: int,
+        regularizer=None,
+        **kwargs
+    ):
+        super().__init__()
+        self.token_nums = token_nums
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        if bottleneck_dim > 0:
+            self.bottleneck_dim = bottleneck_dim
+        else:
+            assert self.input_dim == self.output_dim, "input_dim and output_dim must be the same when bottleneck_dim is not specified"
+            self.bottleneck_dim = self.input_dim
+
+        self.project_dim = self.bottleneck_dim
+
+        if self.bottleneck_dim > 0:
+            self.in_linear = nn.Linear(self.input_dim, self.project_dim)
+            self.out_linear = nn.Linear(self.bottleneck_dim, self.output_dim)
+        else:
+            self.in_linear = self.out_linear = lambda x: x
+
+        regularizer['args']['dim'] = self.bottleneck_dim
+        regularizer['args']['token_nums'] = self.token_nums
+        self.regularizer = models.make(regularizer)
+
+    def project_in(self, x):
+        assert len(x.shape) == 3, "Input shape must be (batch, n_tokens, e_dim)"
+        z = self.in_linear(x)
+        return z
+
+    def project_out(self, z_cat):
+        z = self.out_linear(z_cat)
+        return z
+
+    def decode(self, bottleneck_rep):
+        regularized_z = self.regularizer.decode(bottleneck_rep)
+        return self.project_out(regularized_z)
+
+    def forward(self, x):
+        z = self.project_in(x)
+        projected_z = z
+        regularized_output = self.regularizer(z)
+        x_hat = self.project_out(regularized_output['regularized_z'])
+        bottleneck_rep = regularized_output.pop('bottleneck_rep')
+        return {
+            'output': x_hat,
+            'bottleneck_rep': bottleneck_rep,
+            'projected_z': projected_z,
+            **regularized_output,
+        }
+
+
+@register("simvq")
+class SimVectorQuantizer(nn.Module):
+    def __init__(
+        self,
+        dim,
+        codebook_size,
+        l2_normalized=False,
+        same_index_shape=True,
+        stochastic=False,
+        stochastic_temperature=1.0,
+        **kwargs,
+    ):
+        super().__init__()
+        self.codebook_size = codebook_size
+        self.dim = dim
+        assert isinstance(l2_normalized, bool)
+        self.l2_normalized = l2_normalized
+        self.stochastic = stochastic
+        self.eval_deterministic = False
+        self.default_stochastic_temperature = stochastic_temperature
+
+        if self.stochastic:
+            if stochastic_temperature > 0:  # fixed temperature
+                self.stochastic_temperature_inv = 1 / stochastic_temperature
+            else:  # set stochastic_temperature < 0 to use learnable temperature
+                self.stochastic_temperature_inv = nn.Parameter(torch.tensor(10.0))
+
+        # for clear inference code, we remove the codebook init from LLM's embedding
+        self.embedding = nn.Embedding(self.codebook_size, self.dim)
+        self.embedding_proj = nn.Linear(self.dim, self.dim)
+
+        self.same_index_shape = same_index_shape
+
+    def set_eval_deterministic(self, deterministic=True):
+        self.eval_deterministic = deterministic
+
+    def set_stochastic_temperature(self, temperature):
+        self.stochastic_temperature_inv = 1 / temperature
+
+    @torch.autocast(device_type='cuda', enabled=False)
+    def get_emb(self):
+        emb = self.embedding_proj(self.embedding.weight)
+        if self.l2_normalized:
+            emb = F.normalize(emb, p=2, dim=-1)
+        # assert emb.dtype == torch.float32, f"Embedding weight dtype is {emb.dtype}, expected float32"
+        return emb
+
+    @torch.autocast(device_type='cuda', enabled=False)
+    def forward(self, z):
+        emb = self.get_emb()
+        z = z.to(emb)
+        # z = z.float()
+        assert len(z.shape) == 3, "Input shape must be (batch, n_tokens, e_dim)"
+        if self.l2_normalized:
+            z = F.normalize(z, p=2, dim=-1)
+
+        z_flattened = rearrange(z, 'b n d -> (b n) d')
+
+        if self.stochastic:
+            # sample the softmaxed cosine similarity
+            assert self.l2_normalized, "Stochastic sampling requires l2 normalization"
+            cos_sim = torch.einsum("bd,nd->bn", z_flattened, emb)
+            probs = F.softmax(cos_sim * self.stochastic_temperature_inv, dim=-1)
+            if self.eval_deterministic and not self.training:
+                q_indices = torch.argmax(probs, dim=-1)
+            else:
+                q_indices = torch.multinomial(probs, 1).squeeze(-1)
+        else:
+            d = (
+                torch.sum(z_flattened**2, dim=1, keepdim=True)
+                + torch.sum(emb**2, dim=1)
+                - 2
+                * torch.einsum(
+                    "bd,dn->bn", z_flattened, rearrange(emb, "n d -> d n")
+                )
+            )
+            q_indices = torch.argmin(d, dim=1)
+
+        quantized = F.embedding(q_indices, emb, self.embedding.padding_idx, self.embedding.max_norm,
+                                self.embedding.norm_type, self.embedding.scale_grad_by_freq, self.embedding.sparse).view(z.shape)  # (b, n, d)
+
+        # preserve gradients
+        quantized = z + (quantized - z).detach()
+
+        if self.same_index_shape:
+            q_indices = q_indices.reshape(quantized.shape[0], quantized.shape[1])
+
+        return_dict = {
+            'unregularized_z': z,  # but l2 normalized if l2_normalized=True
+            'emb': emb,  # but l2 normalized if l2_normalized=True
+            'regularized_z': quantized,
+            'bottleneck_rep': q_indices
+        }
+        return return_dict
+
+    def get_codebook_entry(self, indices, shape=None):
+        # shape specifying (batch, height, width, channel)
+        indices_shape = indices.shape
+        indices_flatten = rearrange(indices, '... -> (...)')
+
+        # get quantized latent vectors
+        emb = self.get_emb()
+        z_q = F.embedding(indices_flatten, emb)
+        # z_q = self.embedding(indices_flatten)
+        if self.l2_normalized:
+            z_q = F.normalize(z_q, p=2, dim=-1)
+
+        if shape is not None:
+            z_q = z_q.reshape(shape)
+        else:
+            z_q = z_q.reshape([*indices_shape, self.dim])
+        return z_q
+
+    def decode(self, indices):
+        return self.get_codebook_entry(indices)
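
The deterministic branch of `SimVectorQuantizer` is a plain nearest-codebook lookup followed by a straight-through estimator. The toy snippet below illustrates that idea in isolation; it is a self-contained sketch with made-up sizes, not the repo's class.

```python
# Toy illustration of nearest-codebook lookup + straight-through estimator
# (self-contained sketch; not the SimVectorQuantizer class itself).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
codebook = torch.randn(16, 8)                   # (codebook_size, dim)
z = torch.randn(2, 5, 8, requires_grad=True)    # (batch, n_tokens, dim)

flat = z.reshape(-1, 8)                                  # (b*n, dim)
d = (flat.pow(2).sum(1, keepdim=True)                    # ||z||^2
     + codebook.pow(2).sum(1)                            # + ||e||^2
     - 2 * flat @ codebook.t())                          # - 2 z.e  -> squared distances
indices = d.argmin(dim=1)                                # nearest codebook entry per token
quantized = F.embedding(indices, codebook).view_as(z)

# Straight-through: forward pass uses codebook vectors, backward passes gradients to z.
quantized_st = z + (quantized - z).detach()
quantized_st.sum().backward()
print(indices.view(2, 5).shape, z.grad.shape)  # torch.Size([2, 5]) torch.Size([2, 5, 8])
```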
tok/ar_dtok/generate.py ADDED
@@ -0,0 +1,202 @@
+# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# //
+# // Licensed under the Apache License, Version 2.0 (the "License");
+# // you may not use this file except in compliance with the License.
+# // You may obtain a copy of the License at
+# //
+# // http://www.apache.org/licenses/LICENSE-2.0
+# //
+# // Unless required by applicable law or agreed to in writing, software
+# // distributed under the License is distributed on an "AS IS" BASIS,
+# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# // See the License for the specific language governing permissions and
+# // limitations under the License.
+
+# Modified from:
+#   llamagen: https://github.com/FoundationVision/LlamaGen/blob/main/autoregressive/models/generate.py
+#   gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
+#   DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
+
+
+import torch
+import torch._dynamo.config
+import torch._inductor.config
+from torch.nn import functional as F
+
+
+### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
+def top_k_top_p_filtering(
+    logits,
+    top_k: int = 0,
+    top_p: float = 1.0,
+    filter_value: float = -float("Inf"),
+    min_tokens_to_keep: int = 1,
+):
+    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+    Args:
+        logits: logits distribution shape (batch size, vocabulary size)
+        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+        Make sure we keep at least min_tokens_to_keep per batch example in the output
+    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+    """
+    if top_k > 0:
+        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs > top_p
+        if min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        logits[indices_to_remove] = filter_value
+    return logits
+
+
+def sample(logits, temperature: float=1.0, top_k: int=0, top_p: float=1.0, sample_logits=True):
+    logits = logits[:, -1, :] / max(temperature, 1e-5)
+    if top_k > 0 or top_p < 1.0:
+        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+
+    # improve numerical stability of softmax
+    probs = F.softmax(logits.float(), dim=-1)
+    if sample_logits:
+        idx = torch.multinomial(probs, num_samples=1)
+    else:
+        _, idx = torch.topk(probs, k=1, dim=-1)
+    return idx, probs
+
+
+def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: int = None, **kwargs):
+    logits = logits / max(temperature, 1e-5)
+    if top_k > 0 or top_p < 1.0:
+        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    return probs
+
+
+def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, **sampling_kwargs):
+    if cfg_scale > 1.0:
+        logits, _ = model(None, cond_idx, input_pos)
+        logits_combined = logits
+        cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
+        logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
+    else:
+        logits, _ = model(None, cond_idx, input_pos)
+
+    return sample(logits, **sampling_kwargs)[0]
+
+
+def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, **sampling_kwargs):
+    assert input_pos.shape[-1] == 1
+    if cfg_scale > 1.0:
+        x_combined = torch.cat([x, x])
+        logits, _ = model(x_combined, cond_idx=None, input_pos=input_pos)
+        logits_combined = logits
+        cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
+        if cfg_flag:
+            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
+        else:
+            logits = cond_logits
+    else:
+        logits, _ = model(x, cond_idx=None, input_pos=input_pos)
+    return sample(logits, **sampling_kwargs)
+
+
+def decode_n_tokens(
+    model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
+    cfg_scale: float, cfg_interval: int,
+    **sampling_kwargs):
+    new_tokens, new_probs = [], []
+    cfg_flag = True
+    for i in range(num_new_tokens):
+        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):  # Actually better for Inductor to codegen attention here
+            if cfg_interval > -1 and i > cfg_interval:
+                cfg_flag = False
+            next_token, next_prob = decode_one_token(
+                model, cur_token, input_pos, cfg_scale, cfg_flag, **sampling_kwargs
+            )
+            input_pos += 1
+            new_tokens.append(next_token.clone())
+            new_probs.append(next_prob.clone())
+            cur_token = next_token.view(-1, 1)
+
+    return new_tokens, new_probs
+
+
+@torch.no_grad()
+def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, **sampling_kwargs):
+    if model.frame_prediction:
+        assert cfg_scale == 1.0, "frame prediction requires cfg_scale=1.0 (no classifier-free guidance)"
+        cond_combined = cond
+        T = cond.shape[1]
+    elif model.model_type == 'class_cond':
+        if cfg_scale > 1.0:
+            cond_null = torch.ones_like(cond) * model.num_classes
+            cond_combined = torch.cat([cond, cond_null])
+        else:
+            cond_combined = cond
+        T = 1
+    elif model.model_type == 'clip_cond':
+        if cfg_scale > 1.0:
+            cond_null = torch.zeros_like(cond)
+            cond_combined = torch.cat([cond, cond_null])
+        else:
+            cond_combined = cond
+        T = model.cls_token_num
+    elif model.model_type == 'indice_cond':
+        if cfg_scale > 1.0:
+            cond_null = torch.ones_like(cond) * model.cond_vocab_size
+            cond_combined = torch.cat([cond, cond_null])
+        else:
+            cond_combined = cond
+        T = model.cls_token_num
+    else:
+        raise Exception("please check model type")
+
+    T_new = T + max_new_tokens
+    max_seq_length = T_new
+    max_batch_size = cond.shape[0]
+
+    device = cond.device
+    with torch.device(device):
+        max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size
+        model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=model.tok_embeddings.weight.dtype)
+
+    if emb_masks is not None:
+        assert emb_masks.shape[0] == max_batch_size
+        assert emb_masks.shape[-1] == T
+        if cfg_scale > 1.0:
+            model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
+        else:
+            model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)
+
+        eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
+        model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix
+
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)
+
+    input_pos = torch.arange(0, T, device=device)
+
+    next_token = prefill(model, cond_combined, input_pos, cfg_scale, **sampling_kwargs)
+    seq[:, T:T+1] = next_token
+
+    input_pos = torch.tensor([T], device=device, dtype=torch.int)
+    generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, **sampling_kwargs)
+    seq[:, T+1:] = torch.cat(generated_tokens, dim=1)
+
+    return seq[:, T:]
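
A quick sanity check of `top_k_top_p_filtering` on toy logits, assuming this repo is importable as the `tok` package (the function itself is adapted from the `transformers` generation utilities, so the same behaviour can also be reproduced from that source).

```python
# Sanity-check sketch for top_k_top_p_filtering (assumes the `tok` package is importable).
import torch
from tok.ar_dtok.generate import top_k_top_p_filtering

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])

# Top-k: keep only the 2 highest-scoring tokens; everything else becomes -inf.
top2 = top_k_top_p_filtering(logits.clone(), top_k=2)
print(top2)  # 2.0 and 1.0 survive, the remaining three entries are -inf

# Nucleus (top-p): keep the smallest set of tokens whose cumulative probability >= 0.7.
nucleus = top_k_top_p_filtering(logits.clone(), top_p=0.7)
print(torch.isfinite(nucleus).sum().item(), "tokens survive")  # 2 tokens survive
```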
tok/ar_dtok/vqvae.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+ from typing import List
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from ..models import register
23
+ from ..utils import ScalingLayer
24
+
25
+
26
+ @register('vqvae')
27
+ class VQVAE(nn.Module):
28
+ def __init__(
29
+ self,
30
+ model='VQ-16',
31
+ ckpt='',
32
+ codebook_size=16384,
33
+ codebook_embed_dim=8,
34
+ bottleneck_token_num=256,
35
+ input_size=256,
36
+ *args,
37
+ **kwargs,
38
+ ):
39
+ super().__init__()
40
+ self.codebook_size = codebook_size
41
+ self.codebook_embed_dim = codebook_embed_dim
42
+ self.bottleneck_token_num = bottleneck_token_num
43
+ self.input_size = input_size
44
+ self.model = VQ_models[model](
45
+ codebook_size=codebook_size,
46
+ codebook_embed_dim=codebook_embed_dim)
47
+ ckpt = torch.load(ckpt, map_location='cpu')
48
+ self.model.load_state_dict(ckpt['model'])
49
+ self.model.eval()
50
+
51
+ self.scale_layer = ScalingLayer(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
52
+
53
+ @classmethod
54
+ def from_checkpoint(cls, ckpt, **kwargs):
55
+ model = cls(ckpt=ckpt, **kwargs)
56
+ return model
57
+
58
+ def decode_from_bottleneck(self, z):
59
+ if z.ndim == 2:
60
+ b = z.size(0)
61
+ h = w = int(z.size(-1) ** 0.5)
62
+ z = self.model.decode_code(z, (b, self.codebook_embed_dim, h, w))
63
+ return self.scale_layer.inv(z)
64
+
65
+
66
+ # Adapted from https://github.com/FoundationVision/LlamaGen/blob/main/tokenizer/tokenizer_image/vq_model.py
67
+ @dataclass
68
+ class ModelArgs:
69
+ codebook_size: int = 16384
70
+ codebook_embed_dim: int = 8
71
+ codebook_l2_norm: bool = True
72
+ codebook_show_usage: bool = True
73
+ commit_loss_beta: float = 0.25
74
+ entropy_loss_ratio: float = 0.0
75
+
76
+ encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
77
+ decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
78
+ z_channels: int = 256
79
+ dropout_p: float = 0.0
80
+
81
+
82
+ class VQModel(nn.Module):
83
+ def __init__(self, config: ModelArgs):
84
+ super().__init__()
85
+ self.config = config
86
+ self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
87
+ self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
88
+
89
+ self.quantize = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
90
+ config.commit_loss_beta, config.entropy_loss_ratio,
91
+ config.codebook_l2_norm, config.codebook_show_usage)
92
+ self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
93
+ self.post_quant_conv = nn.Conv2d(config.codebook_embed_dim, config.z_channels, 1)
94
+
95
+ def encode(self, x):
96
+ h = self.encoder(x)
97
+ h = self.quant_conv(h)
98
+ quant, emb_loss, info = self.quantize(h)
99
+ return quant, emb_loss, info
100
+
101
+ def decode(self, quant):
102
+ quant = self.post_quant_conv(quant)
103
+ dec = self.decoder(quant)
104
+ return dec
105
+
106
+ def decode_code(self, code_b, shape=None, channel_first=True):
107
+ quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
108
+ dec = self.decode(quant_b)
109
+ return dec
110
+
111
+ def forward(self, input):
112
+ quant, diff, _ = self.encode(input)
113
+ dec = self.decode(quant)
114
+ return dec, diff
115
+
116
+
117
+ class Encoder(nn.Module):
118
+ def __init__(self, in_channels=3, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2,
119
+ norm_type='group', dropout=0.0, resamp_with_conv=True, z_channels=256):
120
+ super().__init__()
121
+ self.num_resolutions = len(ch_mult)
122
+ self.num_res_blocks = num_res_blocks
123
+ self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
124
+
125
+ # downsampling
126
+ in_ch_mult = (1,) + tuple(ch_mult)
127
+ self.conv_blocks = nn.ModuleList()
128
+ for i_level in range(self.num_resolutions):
129
+ conv_block = nn.Module()
130
+ # res & attn
131
+ res_block = nn.ModuleList()
132
+ attn_block = nn.ModuleList()
133
+ block_in = ch*in_ch_mult[i_level]
134
+ block_out = ch*ch_mult[i_level]
135
+ for _ in range(self.num_res_blocks):
136
+ res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
137
+ block_in = block_out
138
+ if i_level == self.num_resolutions - 1:
139
+ attn_block.append(AttnBlock(block_in, norm_type))
140
+ conv_block.res = res_block
141
+ conv_block.attn = attn_block
142
+ # downsample
143
+ if i_level != self.num_resolutions-1:
144
+ conv_block.downsample = Downsample(block_in, resamp_with_conv)
145
+ self.conv_blocks.append(conv_block)
146
+
147
+ # middle
148
+ self.mid = nn.ModuleList()
149
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
150
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
151
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
152
+
153
+ # end
154
+ self.norm_out = Normalize(block_in, norm_type)
155
+ self.conv_out = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)
156
+
157
+ def forward(self, x):
158
+ h = self.conv_in(x)
159
+ # downsampling
160
+ for i_level, block in enumerate(self.conv_blocks):
161
+ for i_block in range(self.num_res_blocks):
162
+ h = block.res[i_block](h)
163
+ if len(block.attn) > 0:
164
+ h = block.attn[i_block](h)
165
+ if i_level != self.num_resolutions - 1:
166
+ h = block.downsample(h)
167
+
168
+ # middle
169
+ for mid_block in self.mid:
170
+ h = mid_block(h)
171
+
172
+ # end
173
+ h = self.norm_out(h)
174
+ h = nonlinearity(h)
175
+ h = self.conv_out(h)
176
+ return h
177
+
178
+
179
+ class Decoder(nn.Module):
180
+ def __init__(self, z_channels=256, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, norm_type="group",
181
+ dropout=0.0, resamp_with_conv=True, out_channels=3):
182
+ super().__init__()
183
+ self.num_resolutions = len(ch_mult)
184
+ self.num_res_blocks = num_res_blocks
185
+
186
+ block_in = ch*ch_mult[self.num_resolutions-1]
187
+ # z to block_in
188
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
189
+
190
+ # middle
191
+ self.mid = nn.ModuleList()
192
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
193
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
194
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
195
+
196
+ # upsampling
197
+ self.conv_blocks = nn.ModuleList()
198
+ for i_level in reversed(range(self.num_resolutions)):
199
+ conv_block = nn.Module()
200
+ # res & attn
201
+ res_block = nn.ModuleList()
202
+ attn_block = nn.ModuleList()
203
+ block_out = ch*ch_mult[i_level]
204
+ for _ in range(self.num_res_blocks + 1):
205
+ res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
206
+ block_in = block_out
207
+ if i_level == self.num_resolutions - 1:
208
+ attn_block.append(AttnBlock(block_in, norm_type))
209
+ conv_block.res = res_block
210
+ conv_block.attn = attn_block
211
+ # upsample
212
+ if i_level != 0:
213
+ conv_block.upsample = Upsample(block_in, resamp_with_conv)
214
+ self.conv_blocks.append(conv_block)
215
+
216
+ # end
217
+ self.norm_out = Normalize(block_in, norm_type)
218
+ self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
219
+
220
+ @property
221
+ def last_layer(self):
222
+ return self.conv_out.weight
223
+
224
+ def forward(self, z):
225
+ # z to block_in
226
+ h = self.conv_in(z)
227
+
228
+ # middle
229
+ for mid_block in self.mid:
230
+ h = mid_block(h)
231
+
232
+ # upsampling
233
+ for i_level, block in enumerate(self.conv_blocks):
234
+ for i_block in range(self.num_res_blocks + 1):
235
+ h = block.res[i_block](h)
236
+ if len(block.attn) > 0:
237
+ h = block.attn[i_block](h)
238
+ if i_level != self.num_resolutions - 1:
239
+ h = block.upsample(h)
240
+
241
+ # end
242
+ h = self.norm_out(h)
243
+ h = nonlinearity(h)
244
+ h = self.conv_out(h)
245
+ return h
246
+
247
+
248
+ class VectorQuantizer(nn.Module):
249
+ def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
250
+ super().__init__()
251
+ self.n_e = n_e
252
+ self.e_dim = e_dim
253
+ self.beta = beta
254
+ self.entropy_loss_ratio = entropy_loss_ratio
255
+ self.l2_norm = l2_norm
256
+ self.show_usage = show_usage
257
+
258
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
259
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
260
+ if self.l2_norm:
261
+ self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
262
+ if self.show_usage:
263
+ self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
264
+
265
+ def forward(self, z):
266
+ # reshape z -> (batch, height, width, channel) and flatten
267
+ z = torch.einsum('b c h w -> b h w c', z).contiguous()
268
+ z_flattened = z.view(-1, self.e_dim)
269
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
270
+
271
+ if self.l2_norm:
272
+ z = F.normalize(z, p=2, dim=-1)
273
+ z_flattened = F.normalize(z_flattened, p=2, dim=-1)
274
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
275
+ else:
276
+ embedding = self.embedding.weight
277
+
278
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
279
+ torch.sum(embedding**2, dim=1) - 2 * \
280
+ torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))
281
+
282
+ min_encoding_indices = torch.argmin(d, dim=1)
283
+ z_q = embedding[min_encoding_indices].view(z.shape)
284
+ perplexity = None
285
+ min_encodings = None
286
+ vq_loss = None
287
+ commit_loss = None
288
+ entropy_loss = None
289
+ codebook_usage = 0
290
+
291
+ if self.show_usage and self.training:
292
+ cur_len = min_encoding_indices.shape[0]
293
+ self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()
294
+ self.codebook_used[-cur_len:] = min_encoding_indices
295
+ codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e
296
+
297
+ # compute loss for embedding
298
+ if self.training:
299
+ vq_loss = torch.mean((z_q - z.detach()) ** 2)
300
+ commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
301
+ entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
302
+
303
+ # preserve gradients
304
+ z_q = z + (z_q - z).detach()
305
+
306
+ # reshape back to match original input shape
307
+ z_q = torch.einsum('b h w c -> b c h w', z_q)
308
+
309
+ return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)
310
+
311
+ def get_codebook_entry(self, indices, shape=None, channel_first=True):
312
+ # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
313
+ if self.l2_norm:
314
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
315
+ else:
316
+ embedding = self.embedding.weight
317
+ z_q = embedding[indices] # (b*h*w, c)
318
+
319
+ if shape is not None:
320
+ if channel_first:
321
+ z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
322
+ # reshape back to match original input shape
323
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
324
+ else:
325
+ z_q = z_q.view(shape)
326
+ return z_q
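The `z_q = z + (z_q - z).detach()` line in `VectorQuantizer.forward` is the usual straight-through estimator: the forward pass sees the quantized values, while the gradient skips the non-differentiable codebook lookup and flows into the encoder features. A toy, self-contained illustration (rounding stands in for the nearest-codebook lookup):

```python
import torch

z = torch.randn(4, 8, requires_grad=True)
z_q = torch.round(z)              # stand-in for the nearest-codebook lookup
z_st = z + (z_q - z).detach()     # forward value: z_q; backward: gradient passes straight to z

loss = (z_st ** 2).sum()
loss.backward()
print(torch.allclose(z_st, z_q))  # True -- downstream layers see the quantized tensor
print(z.grad is not None)         # True -- the encoder side still receives gradients
```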
327
+
328
+
329
+ class ResnetBlock(nn.Module):
330
+ def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group'):
331
+ super().__init__()
332
+ self.in_channels = in_channels
333
+ out_channels = in_channels if out_channels is None else out_channels
334
+ self.out_channels = out_channels
335
+ self.use_conv_shortcut = conv_shortcut
336
+
337
+ self.norm1 = Normalize(in_channels, norm_type)
338
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
339
+ self.norm2 = Normalize(out_channels, norm_type)
340
+ self.dropout = nn.Dropout(dropout)
341
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
342
+
343
+ if self.in_channels != self.out_channels:
344
+ if self.use_conv_shortcut:
345
+ self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
346
+ else:
347
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
348
+
349
+ def forward(self, x):
350
+ h = x
351
+ h = self.norm1(h)
352
+ h = nonlinearity(h)
353
+ h = self.conv1(h)
354
+ h = self.norm2(h)
355
+ h = nonlinearity(h)
356
+ h = self.dropout(h)
357
+ h = self.conv2(h)
358
+
359
+ if self.in_channels != self.out_channels:
360
+ if self.use_conv_shortcut:
361
+ x = self.conv_shortcut(x)
362
+ else:
363
+ x = self.nin_shortcut(x)
364
+ return x+h
365
+
366
+
367
+ class AttnBlock(nn.Module):
368
+ def __init__(self, in_channels, norm_type='group'):
369
+ super().__init__()
370
+ self.norm = Normalize(in_channels, norm_type)
371
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
372
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
373
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
374
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
375
+
376
+ def forward(self, x):
377
+ h_ = x
378
+ h_ = self.norm(h_)
379
+ q = self.q(h_)
380
+ k = self.k(h_)
381
+ v = self.v(h_)
382
+
383
+ # compute attention
384
+ b,c,h,w = q.shape
385
+ q = q.reshape(b,c,h*w)
386
+ q = q.permute(0,2,1) # b,hw,c
387
+ k = k.reshape(b,c,h*w) # b,c,hw
388
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
389
+ w_ = w_ * (int(c)**(-0.5))
390
+ w_ = F.softmax(w_, dim=2)
391
+
392
+ # attend to values
393
+ v = v.reshape(b,c,h*w)
394
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
395
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
396
+ h_ = h_.reshape(b,c,h,w)
397
+
398
+ h_ = self.proj_out(h_)
399
+
400
+ return x+h_
401
+
402
+ def nonlinearity(x):
403
+ # swish
404
+ return x*torch.sigmoid(x)
405
+
406
+ def Normalize(in_channels, norm_type='group'):
407
+ assert norm_type in ['group', 'batch']
408
+ if norm_type == 'group':
409
+ return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
410
+ elif norm_type == 'batch':
411
+ return nn.SyncBatchNorm(in_channels)
412
+
413
+
414
+ class Upsample(nn.Module):
415
+ def __init__(self, in_channels, with_conv):
416
+ super().__init__()
417
+ self.with_conv = with_conv
418
+ if self.with_conv:
419
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
420
+
421
+ def forward(self, x):
422
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
423
+ if self.with_conv:
424
+ x = self.conv(x)
425
+ return x
426
+
427
+
428
+ class Downsample(nn.Module):
429
+ def __init__(self, in_channels, with_conv):
430
+ super().__init__()
431
+ self.with_conv = with_conv
432
+ if self.with_conv:
433
+ # no asymmetric padding in torch conv, must do it ourselves
434
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
435
+
436
+ def forward(self, x):
437
+ if self.with_conv:
438
+ pad = (0,1,0,1)
439
+ x = F.pad(x, pad, mode="constant", value=0)
440
+ x = self.conv(x)
441
+ else:
442
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
443
+ return x
444
+
445
+
446
+ def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
447
+ flat_affinity = affinity.reshape(-1, affinity.shape[-1])
448
+ flat_affinity /= temperature
449
+ probs = F.softmax(flat_affinity, dim=-1)
450
+ log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
451
+ if loss_type == "softmax":
452
+ target_probs = probs
453
+ else:
454
+ raise ValueError("Entropy loss {} not supported".format(loss_type))
455
+ avg_probs = torch.mean(target_probs, dim=0)
456
+ avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
457
+ sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
458
+ loss = sample_entropy - avg_entropy
459
+ return loss
460
+
461
+
462
+ #################################################################################
463
+ # VQ Model Configs #
464
+ #################################################################################
465
+ def VQ_8(**kwargs):
466
+ return VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs))
467
+
468
+ def VQ_16(**kwargs):
469
+ return VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs))
470
+
471
+ VQ_models = {'VQ-16': VQ_16, 'VQ-8': VQ_8}
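For orientation, a minimal shape-check sketch of the `VQ_16` model defined above, run with random weights (the Space itself loads pretrained weights through `VQVAE.from_checkpoint`). The five-stage encoder gives 16x spatial downsampling and the quantizer works on 8-dim codes, so a 256x256 input becomes a 16x16 grid of codebook indices:

```python
import torch

model = VQ_16(codebook_size=16384, codebook_embed_dim=8).eval()

x = torch.randn(1, 3, 256, 256)                  # dummy image tensor
with torch.no_grad():
    quant, emb_loss, info = model.encode(x)      # quant: (1, 8, 16, 16)
    rec = model.decode(quant)                    # rec:   (1, 3, 256, 256)
    indices = info[2]                            # flat codebook indices, shape (256,)
    rec2 = model.decode_code(indices, shape=(1, 8, 16, 16))

print(quant.shape, rec.shape, rec2.shape)
```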
tok/mm_autoencoder.py ADDED
@@ -0,0 +1,79 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ from tok.ar_dtok.ar_model import ARModel
19
+ from tok.ar_dtok.vqvae import VQVAE
20
+ from tok.ta_tok import TextAlignedTokenizer
21
+
22
+
23
+ class MMAutoEncoder(nn.Module):
24
+ def __init__(self,
25
+ ar_path_dict,
26
+ encoder_path, decoder_path,
27
+ encoder_args={}, decoder_args={}):
28
+ super().__init__()
29
+ self.ar_model = nn.ModuleDict({resolution: ARModel.from_checkpoint(ar_path) for resolution, ar_path in ar_path_dict.items()})
30
+
31
+ self.encoder = TextAlignedTokenizer.from_checkpoint(encoder_path, load_teacher=False, **encoder_args)
32
+ self.decoder = VQVAE.from_checkpoint(decoder_path, **decoder_args)
33
+
34
+ def ar_sample(self, x, args):
35
+ resolution = args.get("resolution", "1024px")
36
+ x = self.ar_model[resolution].sample(
37
+ x,
38
+ cfg_scale=args.get('cfg_scale', 1.0),
39
+ cfg_interval=args.get('cfg_interval', -1),
40
+ temperature=args.get('temperature', 1.0),
41
+ top_k=args.get('top_k', 0),
42
+ top_p=args.get('top_p', 1.0)
43
+ )
44
+ return x
45
+
46
+ def post_process(self, x):
47
+ x = x.cpu().float().clamp(0., 1.) * 255.
48
+ x = x.permute(0, 2, 3, 1) # [b, h, w, c]
49
+ x = x.to(torch.uint8)
50
+ return x
51
+
52
+ def encode(self, x):
53
+ return self.encoder(x.to(self.encoder.dtype))['encoded']
54
+
55
+ def get_encoder_indices(self, x):
56
+ # img -> encoder -> indices
57
+ return self.encoder(x.to(self.encoder.dtype))['bottleneck_rep']
58
+
59
+ @torch.inference_mode()
60
+ def decode_from_encoder_indices(self, indices, args={}):
61
+ # indices -> encoder feats -> ar -> decoder
62
+ encoder_x = self.encoder.decode_from_bottleneck(indices)
63
+ ar_indices = self.ar_sample(encoder_x, args)
64
+ decoder_x = self.decoder.decode_from_bottleneck(ar_indices)
65
+ x = self.post_process(decoder_x)
66
+ return x
67
+
68
+ def decode_from_vqvae_indices(self, indices):
69
+ decoder_x = self.decoder.decode_from_bottleneck(indices)
70
+ x = self.post_process(decoder_x)
71
+ return x
72
+
73
+ @torch.inference_mode()
74
+ def forward(self, x, args={}):
75
+ encoder_x = self.encoder(x.to(self.encoder.dtype))['encoded']
76
+ ar_indices = self.ar_sample(encoder_x, args)
77
+ decoder_x = self.decoder.decode_from_bottleneck(ar_indices)
78
+ x = self.post_process(decoder_x)
79
+ return x
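Putting the pieces together, `MMAutoEncoder` chains the text-aligned encoder, a resolution-specific AR prior, and the VQ-VAE pixel decoder. A hypothetical wiring sketch; the checkpoint paths below are placeholders, not files added in this commit, and the sampling arguments are illustrative values only:

```python
import torch

mm = MMAutoEncoder(
    ar_path_dict={'1024px': 'checkpoints/ar_1024px.pt'},   # placeholder path
    encoder_path='checkpoints/ta_tok.pth',                  # placeholder path
    decoder_path='checkpoints/vq_ds16.pt',                  # placeholder path
    encoder_args={'pool_scale': 1},
    decoder_args={},
).eval()

img = torch.rand(1, 3, 384, 384)                            # dummy image in [0, 1]
out = mm(img, args={'resolution': '1024px', 'cfg_scale': 7.5,
                    'temperature': 1.0, 'top_k': 1200, 'top_p': 0.8})
print(out.shape, out.dtype)                                 # (1, H, W, 3), torch.uint8
```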
tok/models.py ADDED
@@ -0,0 +1,45 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ import copy
16
+ import inspect
17
+
18
+ import torch
19
+
20
+ models = {}
21
+
22
+
23
+ def register(name):
24
+ def decorator(cls):
25
+ models[name] = cls
26
+ return cls
27
+ return decorator
28
+
29
+
30
+ def make(model_spec, args=None, load_sd=False) -> torch.nn.Module:
31
+ if args is not None:
32
+ model_args = copy.deepcopy(model_spec['args'])
33
+ model_args.update(args)
34
+ else:
35
+ model_args = model_spec['args']
36
+ model_params = inspect.signature(models[model_spec['name']]).parameters
37
+ if 'kwargs' not in model_params:
38
+ model_args = {k: v for k, v in model_args.items() if k in model_params}
39
+ model = models[model_spec['name']](**model_args)
40
+ if load_sd:
41
+ if ('abs_pe' in model_spec['sd']) and hasattr(model, 'abs_pe') and model_spec['sd']['abs_pe'].shape != model.abs_pe.shape:
42
+ del model_spec['sd']['abs_pe']
43
+ msg = model.load_state_dict(model_spec['sd'], strict=False)
44
+ print(msg)
45
+ return model
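The `register`/`make` pair above implements a small model registry: classes register themselves under a string name, and `make` instantiates them from a spec dict of the form `{'name': ..., 'args': ..., 'sd': ...}`, filtering out arguments the constructor does not accept. A toy example of how it is meant to be used (the `toy_linear` class is purely illustrative):

```python
import torch.nn as nn

@register('toy_linear')
class ToyLinear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.proj(x)

spec = {'name': 'toy_linear', 'args': {'in_dim': 8, 'out_dim': 4}}
model = make(spec)                        # builds ToyLinear(8, 4)
model = make(spec, args={'out_dim': 2})   # per-call override of a single argument
```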
tok/ta_tok.py ADDED
@@ -0,0 +1,170 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from einops import rearrange
19
+ from torchvision.transforms import Resize
20
+ from transformers import AutoConfig, AutoModel, Siglip2VisionConfig, Siglip2VisionModel
21
+
22
+ from . import models
23
+ from .utils import ScalingLayer
24
+
25
+
26
+ class TextAlignedTokenizer(nn.Module):
27
+ def __init__(
28
+ self,
29
+ bottleneck,
30
+ bottleneck_token_num=256,
31
+ input_size=384,
32
+ teacher='google/siglip2-so400m-patch14-384',
33
+ input_type='quant', # choose from ['quant', 'rec', 'indices']
34
+ pool_scale=1, # choose from [1, 2, 3]
35
+ decoder_depth=3,
36
+ select_layer_id=-2,
37
+ *args,
38
+ **kwargs
39
+ ):
40
+ super().__init__()
41
+ self.input_size = input_size
42
+ self.bottleneck_token_num = bottleneck_token_num
43
+ self.teacher = teacher
44
+ self.input_type = input_type
45
+ self.pool_scale = pool_scale
46
+ self.decoder_depth = decoder_depth
47
+ self.select_layer_id = select_layer_id
48
+
49
+ self.bottleneck_dim = bottleneck['args']['bottleneck_dim']
50
+
51
+ self.encoder_config = AutoConfig.from_pretrained(teacher)
52
+ self.encoder = AutoModel.from_config(self.encoder_config).vision_model
53
+
54
+ self.encoder_hidden_dim = self.encoder.config.hidden_size
55
+
56
+ self.decoder_config = Siglip2VisionConfig()
57
+ self.decoder_config.update({
58
+ 'patch_size': 1,
59
+ 'num_hidden_layers': self.decoder_depth,
60
+ 'num_channels': self.bottleneck_dim,
61
+ 'hidden_size': self.encoder_hidden_dim,
62
+ })
63
+ self.decoder = Siglip2VisionModel(self.decoder_config)
64
+
65
+ self.encode_task_layer = nn.Sequential(
66
+ nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
67
+ nn.Tanh())
68
+ self.decode_task_layer = nn.Sequential(
69
+ nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
70
+ nn.Tanh(),
71
+ nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim))
72
+
73
+ bottleneck_args = {
74
+ 'token_nums': self.bottleneck_token_num,
75
+ 'input_dim': self.encoder_hidden_dim,
76
+ 'output_dim': self.bottleneck_dim}
77
+ self.bottleneck = models.make(bottleneck, args=bottleneck_args)
78
+
79
+ self.scale_layer = ScalingLayer(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
80
+ self.image_resize = Resize((self.input_size, self.input_size))
81
+
82
+ def set_vq_eval_deterministic(self, deterministic=True):
83
+ self.bottleneck.regularizer.set_eval_deterministic(deterministic)
84
+
85
+ @property
86
+ def device(self):
87
+ return next(self.parameters()).device
88
+
89
+ @property
90
+ def dtype(self):
91
+ return next(self.parameters()).dtype
92
+
93
+ @classmethod
94
+ def from_checkpoint(cls, ckpt, load_teacher=True, **kwargs):
95
+ ckpt = torch.load(ckpt, map_location='cpu')
96
+ ckpt_kwargs = ckpt["model"]["args"]
97
+ model = cls(**kwargs, **ckpt_kwargs)
98
+ sd = ckpt["model"]["sd"]
99
+ if not load_teacher:
100
+ sd = {k: v for k, v in sd.items() if not k.startswith('teacher')}
101
+ model.load_state_dict(sd, strict=True)
102
+ return model
103
+
104
+ def encode(self, x, **kwargs):
105
+ if x.ndim == 5:
106
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
107
+ x = self.scale_layer(x)
108
+ if tuple(x.shape[-2:]) != (self.input_size, self.input_size):
109
+ x = self.image_resize(x)
110
+ vq_feats = self.encoder(x, output_hidden_states=True).hidden_states[self.select_layer_id]
111
+
112
+ pool_scale = self.pool_scale
113
+ pool_scale = kwargs.get("pool_scale", pool_scale)
114
+ if pool_scale != 1:
115
+ vq_feats = self.avg_pool(vq_feats, pool_scale)
116
+ vq_feats = self.encode_task_layer(vq_feats.to(x))
117
+
118
+ bottleneck_out = self.bottleneck(vq_feats)
119
+ z = bottleneck_out.pop('output')
120
+
121
+ return {'encoded': z, 'pool_scale': pool_scale, 'vq_feats': vq_feats, **bottleneck_out}
122
+
123
+ def avg_pool(self, z, pool_scale=1):
124
+ if z.ndim == 3:
125
+ b, n, c = z.shape
126
+ p = int(n ** 0.5)
127
+ z = rearrange(z, 'b (p1 p2) c -> b c p1 p2', p1=p, p2=p)
128
+ else:
129
+ b, c, p, _ = z.shape
130
+ p_s = int(p // pool_scale)
131
+ z = F.avg_pool2d(
132
+ z,
133
+ kernel_size=(pool_scale, pool_scale),
134
+ stride=(pool_scale, pool_scale)
135
+ ).contiguous()
136
+ z = rearrange(z, 'b c p1 p2 -> b (p1 p2) c')
137
+ return z
138
+
139
+ def decode(self, z):
140
+ if z.ndim == 4:
141
+ z = rearrange(z, 'b c p1 p2 -> b (p1 p2) c')
142
+ attention_mask = torch.ones(z.shape[:2], dtype=torch.int, device=z.device)
143
+ p = int(z.shape[1]**0.5)
144
+ spatial_shape = torch.tensor([[p, p]]*z.shape[0], device=self.device)
145
+ z = self.decoder(z, attention_mask, spatial_shape, output_hidden_states=True).last_hidden_state
146
+ z = self.decode_task_layer(z)
147
+ return z
148
+
149
+ def decode_from_bottleneck(self, bottleneck_rep):
150
+ z = self.bottleneck.decode(bottleneck_rep) # (b, n, c)
151
+ p = int(z.shape[1]**0.5)
152
+ z = rearrange(z, 'b (p1 p2) c -> b c p1 p2', p1=p, p2=p)
153
+ return self.decode(z)
154
+
155
+ def forward(self, data, **kwargs):
156
+ # data: video in shape (b, c, t, h, w)
157
+ encode_output = self.encode(data, **kwargs)
158
+ vq_feats = encode_output['encoded']
159
+ p = int(vq_feats.shape[1] ** 0.5)
160
+ vq_feats = rearrange(vq_feats, 'b (h w) c -> b c h w', h=p, w=p)
161
+ pred_feats = self.decode(vq_feats)
162
+
163
+ if self.input_type == 'quant':
164
+ z = encode_output["regularized_z"] # [b, n, c]
165
+ elif self.input_type == 'indices':
166
+ z = encode_output["bottleneck_rep"] # [b, n]
167
+ elif self.input_type == 'rec':
168
+ z = pred_feats # [b, n, c]
169
+ encode_output['encoded'] = z
170
+ return encode_output
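`TextAlignedTokenizer` is normally restored via `from_checkpoint`, which rebuilds the module from the arguments stored in the checkpoint and then loads the weights. A hypothetical usage sketch (`checkpoints/ta_tok.pth` is a placeholder path, and the exact output shapes depend on the bottleneck configuration stored in that checkpoint):

```python
import torch

ta_tok = TextAlignedTokenizer.from_checkpoint('checkpoints/ta_tok.pth',  # placeholder path
                                               load_teacher=False).eval()

img = torch.rand(1, 3, 384, 384)       # dummy image in [0, 1]
with torch.no_grad():
    out = ta_tok.encode(img, pool_scale=1)

print(out['encoded'].shape)            # continuous tokens, expected (1, n, bottleneck_dim)
print(out['bottleneck_rep'].shape)     # discrete indices produced by the bottleneck
```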
tok/utils.py ADDED
@@ -0,0 +1,29 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+
19
+ class ScalingLayer(nn.Module):
20
+ def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
21
+ super().__init__()
22
+ self.register_buffer('shift', torch.Tensor(mean)[None, :, None, None])
23
+ self.register_buffer('scale', torch.Tensor(std)[None, :, None, None])
24
+
25
+ def forward(self, inp):
26
+ return (inp - self.shift) / self.scale
27
+
28
+ def inv(self, inp):
29
+ return inp * self.scale + self.shift
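Finally, `ScalingLayer` is the shared input normalization: with `mean=std=0.5` the forward pass maps `[0, 1]` images to `[-1, 1]`, and `inv` undoes it (as used in `VQVAE.decode_from_bottleneck`). A quick sanity-check sketch:

```python
import torch

scaler = ScalingLayer(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
x = torch.rand(2, 3, 8, 8)                     # values in [0, 1]
y = scaler(x)                                  # values in [-1, 1]
assert torch.allclose(scaler.inv(y), x, atol=1e-6)
```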