Add all of `fourm`
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +1 -0
- fourm/data/__init__.py +4 -0
- fourm/data/dataset_utils.py +85 -0
- fourm/data/image_augmenter.py +186 -0
- fourm/data/masking.py +747 -0
- fourm/data/modality_info.py +427 -0
- fourm/data/modality_transforms.py +1387 -0
- fourm/data/multimodal_dataset_folder.py +363 -0
- fourm/data/pretrain_utils.py +292 -0
- fourm/data/transfer_utils.py +53 -0
- fourm/data/unified_datasets.py +557 -0
- fourm/demo_4M_sampler.py +540 -0
- fourm/models/__init__.py +0 -0
- fourm/models/decoder_embeddings.py +268 -0
- fourm/models/encoder_embeddings.py +422 -0
- fourm/models/fm.py +1130 -0
- fourm/models/fm_utils.py +387 -0
- fourm/models/fm_vit.py +485 -0
- fourm/models/generate.py +1273 -0
- fourm/models/lora_utils.py +177 -0
- fourm/utils/__init__.py +22 -0
- fourm/utils/checkpoint.py +185 -0
- fourm/utils/clip/__init__.py +2 -0
- fourm/utils/clip/clip.py +236 -0
- fourm/utils/clip/model.py +504 -0
- fourm/utils/clip/simple_tokenizer.py +137 -0
- fourm/utils/data_constants.py +40 -0
- fourm/utils/dist.py +100 -0
- fourm/utils/fsdp_utils.py +116 -0
- fourm/utils/generation.py +99 -0
- fourm/utils/generation_datasets/PartiPrompts.tsv +0 -0
- fourm/utils/generation_datasets/__init__.py +3 -0
- fourm/utils/generation_datasets/empty_dataset.py +27 -0
- fourm/utils/generation_datasets/image_caption_dataset.py +99 -0
- fourm/utils/generation_datasets/parti_prompts_dataset.py +114 -0
- fourm/utils/hmr2_utils/hmr2/__init__.py +0 -0
- fourm/utils/hmr2_utils/hmr2/models/__init__.py +2 -0
- fourm/utils/hmr2_utils/hmr2/models/backbones/__init__.py +13 -0
- fourm/utils/hmr2_utils/hmr2/models/backbones/vit.py +353 -0
- fourm/utils/hmr2_utils/hmr2/models/components/__init__.py +0 -0
- fourm/utils/hmr2_utils/hmr2/models/components/pose_transformer.py +363 -0
- fourm/utils/hmr2_utils/hmr2/models/components/t_cond_mlp.py +204 -0
- fourm/utils/hmr2_utils/hmr2/models/heads/__init__.py +1 -0
- fourm/utils/hmr2_utils/hmr2/models/heads/smpl_head.py +116 -0
- fourm/utils/hmr2_utils/hmr2/models/hmr2.py +117 -0
- fourm/utils/hmr2_utils/hmr2/models/smpl_wrapper.py +47 -0
- fourm/utils/hmr2_utils/hmr2/utils/__init__.py +31 -0
- fourm/utils/hmr2_utils/hmr2/utils/geometry.py +109 -0
- fourm/utils/hmr2_utils/hmr2/utils/mesh_renderer.py +155 -0
- fourm/utils/hmr2_utils/hmr2/utils/render_openpose.py +155 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+fourm/utils/clip/bpe_simple_vocab_16e6.txt.gz filter=lfs diff=lfs merge=lfs -text
fourm/data/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .image_augmenter import *
+from .modality_transforms import *
+from .unified_datasets import build_fm_pretraining_dataset, build_fm_transfer_dataset, build_wds_fm_pretraining_dataloader, build_wds_divae_dataloader, build_huggingface_pretraining_dataloader, build_mixture_dataloader
+from .pretrain_utils import *
fourm/data/dataset_utils.py
ADDED
@@ -0,0 +1,85 @@
+# Copyright 2024 EPFL and Apple Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+from torch.utils.data import Dataset
+
+
+class RepeatedDatasetWrapper(Dataset):
+    def __init__(self, original_dataset, num_repeats):
+        """
+        Dataset wrapper that repeats the original dataset n times.
+
+        Args:
+            original_dataset (torch.utils.data.Dataset): The original dataset to be repeated.
+            num_repeats (int): The number of times the dataset should be repeated.
+        """
+        self.original_dataset = original_dataset
+        self.num_repeats = num_repeats
+
+    def __getitem__(self, index):
+        """
+        Retrieve the item at the given index.
+
+        Args:
+            index (int): The index of the item to be retrieved.
+        """
+        original_index = index % len(self.original_dataset)
+        return self.original_dataset[original_index]
+
+    def __len__(self):
+        """
+        Get the length of the dataset after repeating it n times.
+
+        Returns:
+            int: The length of the dataset.
+        """
+        return len(self.original_dataset) * self.num_repeats
+
+
+class SubsampleDatasetWrapper(Dataset):
+    def __init__(self, original_dataset, dataset_size, seed=0, return_orig_idx=False):
+        """
+        Dataset wrapper that randomly subsamples the original dataset.
+
+        Args:
+            original_dataset (torch.utils.data.Dataset): The original dataset to be subsampled.
+            dataset_size (int): The size of the subsampled dataset.
+            seed (int): The seed to use for selecting the subset of indices of the original dataset.
+            return_orig_idx (bool): Whether to return the original index of the item in the original dataset.
+        """
+        self.original_dataset = original_dataset
+        self.dataset_size = dataset_size or len(original_dataset)
+        self.return_orig_idx = return_orig_idx
+        np.random.seed(seed)
+        self.indices = np.random.permutation(len(self.original_dataset))[:self.dataset_size]
+
+    def __getitem__(self, index):
+        """
+        Retrieve the item at the given index.
+
+        Args:
+            index (int): The index of the item to be retrieved.
+        """
+        original_index = self.indices[index]
+        sample = self.original_dataset[original_index]
+        return sample, original_index if self.return_orig_idx else sample
+
+    def __len__(self):
+        """
+        Get the length of the dataset after subsampling it.
+
+        Returns:
+            int: The length of the dataset.
+        """
+        return len(self.indices)
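A minimal usage sketch of the two wrappers added above, assuming the fourm package and PyTorch are installed. The toy TensorDataset and the printed lengths are illustrative assumptions, not part of the commit.

import torch
from torch.utils.data import TensorDataset

from fourm.data.dataset_utils import RepeatedDatasetWrapper, SubsampleDatasetWrapper

# Toy dataset with 10 items, each a 1-tuple holding a scalar tensor.
base = TensorDataset(torch.arange(10))

# Repeat the dataset 3 times: indices wrap around the original length.
repeated = RepeatedDatasetWrapper(base, num_repeats=3)
print(len(repeated))   # 30

# Randomly keep 5 of the 10 items; the NumPy seed fixes the permutation.
subsampled = SubsampleDatasetWrapper(base, dataset_size=5, seed=0)
print(len(subsampled))  # 5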
fourm/data/image_augmenter.py
ADDED
@@ -0,0 +1,186 @@
+# Copyright 2024 EPFL and Apple Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torchvision
+
+from fourm.utils import to_2tuple
+
+
+class AbstractImageAugmenter(ABC):
+    """Abstract class for image augmenters.
+    """
+
+    @abstractmethod
+    def __call__(self, mod_dict, crop_settings):
+        pass
+
+
+class RandomCropImageAugmenter(AbstractImageAugmenter):
+
+    def __init__(self, target_size=224, hflip=0.5, crop_scale=(0.2, 1.0), crop_ratio=(0.75, 1.3333), main_domain='rgb'):
+
+        self.target_size = to_2tuple(target_size)
+        self.hflip = hflip
+        self.crop_scale = crop_scale
+        self.crop_ratio = crop_ratio
+        self.main_domain = main_domain
+
+    def __call__(self, mod_dict, crop_settings):
+
+        if crop_settings is not None:
+            raise ValueError("Crop settings are provided but not used by this augmenter.")
+
+        image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
+        # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
+        orig_width, orig_height = image.size
+        orig_size = (orig_height, orig_width)
+
+        top, left, h, w = torchvision.transforms.RandomResizedCrop.get_params(
+            image, scale=self.crop_scale, ratio=self.crop_ratio
+        )
+        crop_coords = top, left, h, w
+        flip = random.random() < self.hflip
+        rand_aug_idx = None
+
+        return crop_coords, flip, orig_size, self.target_size, rand_aug_idx
+
+class NoImageAugmenter(AbstractImageAugmenter): # this is for non-image modalities like poses where we don't do any augs, e.g. during tokenization
+
+    def __init__(self, no_aug=True, main_domain='human_poses'):
+        self.target_size = None #to_2tuple(target_size)
+        self.no_aug = no_aug
+        self.main_domain = main_domain
+
+    def __call__(self, mod_dict, crop_settings):
+        # # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
+        orig_size = (224, 224)
+
+        rand_aug_idx = 0
+        top, left, h, w, flip = 0, 0, 224, 224, 0
+        crop_coords = (top, left, h, w)
+
+        return crop_coords, flip, orig_size, self.target_size, rand_aug_idx
+
+class PreTokenizedImageAugmenter(AbstractImageAugmenter):
+
+    def __init__(self, target_size, no_aug=False, main_domain='rgb'):
+        self.target_size = to_2tuple(target_size)
+        self.no_aug = no_aug
+        self.main_domain = main_domain
+
+    def __call__(self, mod_dict, crop_settings):
+        # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
+        if self.main_domain in mod_dict and 'tok' not in self.main_domain:
+            image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
+            orig_width, orig_height = image.size
+            orig_size = (orig_height, orig_width)
+        else:
+            orig_size = None
+
+        rand_aug_idx = 0 if self.no_aug else np.random.randint(len(crop_settings))
+        top, left, h, w, flip = crop_settings[rand_aug_idx]
+        crop_coords = (top, left, h, w)
+
+        return crop_coords, flip, orig_size, self.target_size, rand_aug_idx
+
+
+class CenterCropImageAugmenter(AbstractImageAugmenter):
+    def __init__(self, target_size, hflip=0.0, main_domain='rgb'):
+        self.target_size = to_2tuple(target_size)
+        self.hflip = hflip
+        self.main_domain = main_domain
+
+    def __call__(self, mod_dict, crop_settings=None):
+        image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
+        orig_width, orig_height = image.size
+        orig_size = (orig_height, orig_width)
+
+        if orig_height > orig_width:
+            h = w = orig_width
+            top = (orig_height - orig_width) // 2
+            left = 0
+        else:
+            h = w = orig_height
+            top = 0
+            left = (orig_width - orig_height) // 2
+
+        crop_coords = (top, left, h, w)
+        flip = random.random() < self.hflip
+        rand_aug_idx = None
+
+        return crop_coords, flip, orig_size, self.target_size, rand_aug_idx
+
+
+class PaddingImageAugmenter(AbstractImageAugmenter):
+    def __init__(self, target_size, hflip=0.0, main_domain='rgb'):
+        self.target_size = to_2tuple(target_size)
+        self.hflip = hflip
+        self.main_domain = main_domain
+
+    def __call__(self, mod_dict, crop_settings):
+        image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
+        orig_width, orig_height = image.size
+        orig_size = (orig_height, orig_width)
+
+        h = w = max(orig_width, orig_height)
+        top = left = 0
+        crop_coords = (top, left, h, w)
+        flip = random.random() < self.hflip
+        rand_aug_idx = None
+
+        return crop_coords, flip, orig_size, self.target_size, rand_aug_idx
+
+
+class ScaleJitteringImageAugmenter(AbstractImageAugmenter):
+    def __init__(self, target_size, hflip=0.0, scale=(0.1, 2.0), main_domain='rgb'):
+        self.target_size = to_2tuple(target_size)
+        self.hflip = hflip
+        self.scale = scale
+        self.main_domain = main_domain
+
+    def scale_jitter(self, orig_height, orig_width):
+        rand_scale = np.random.uniform(self.scale[0], self.scale[1])
+        max_hw = max(orig_height, orig_width)
+        h = w = round(max_hw / rand_scale)
+        top = round(max(0, np.random.uniform(0, orig_height - h)))
+        left = round(max(0, np.random.uniform(0, orig_width - w)))
+
+        return top, left, h, w
+
+    def __call__(self, mod_dict, crop_settings):
+
+        if crop_settings is not None:
+            raise ValueError("Crop settings are provided but not used by this augmenter.")
+
+        image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
+        # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
+        orig_width, orig_height = image.size
+        orig_size = (orig_height, orig_width)
+
+        crop_coords = self.scale_jitter(orig_height, orig_width)
+        flip = random.random() < self.hflip
+        rand_aug_idx = None
+
+        return crop_coords, flip, orig_size, self.target_size, rand_aug_idx
+
+
+class EmptyAugmenter(AbstractImageAugmenter):
+    def __init__(self):
+        pass
+
+    def __call__(self, mod_dict, crop_settings):
+        return None, None, None, None, None
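A hedged example of how one of these augmenters might be called. The blank 640x480 PIL image and the 'rgb' key are assumptions mirroring the main_domain default; note the call returns crop parameters rather than a transformed image.

from PIL import Image

from fourm.data.image_augmenter import RandomCropImageAugmenter

augmenter = RandomCropImageAugmenter(target_size=224, hflip=0.5)
mod_dict = {'rgb': Image.new('RGB', (640, 480))}  # hypothetical placeholder image

# crop_settings must be None here, otherwise this augmenter raises a ValueError.
crop_coords, flip, orig_size, target_size, rand_aug_idx = augmenter(mod_dict, crop_settings=None)
print(crop_coords, flip, orig_size, target_size, rand_aug_idx)
# e.g. (top, left, h, w), a bool flip flag, (480, 640), (224, 224), None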
fourm/data/masking.py
ADDED
@@ -0,0 +1,747 @@
+# Copyright 2024 EPFL and Apple Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import random
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from tokenizers import Tokenizer
+from torch.distributions import Dirichlet
+
+from fourm.data.modality_transforms import get_transform_key
+from fourm.utils import to_2tuple
+from fourm.utils.tokenizer import get_sentinel_to_id_mapping
+
+
+def sample_cosine(min_val: float = 0, max_val: float =1) -> float:
+    """Sample a value from a cosine distribution between min_val and max_val
+
+    Args:
+        min_val: Minimum value
+        max_val: Maximum value
+
+    Returns:
+        Sampled value
+    """
+
+    return min_val + 0.5 * (max_val - min_val) * (1 + math.cos(math.pi * random.uniform(0, 1)))
+
+
+def sample_uniform(min_val: float = 0, max_val: float =1) -> float:
+    """Sample a value from a uniform distribution between min_val and max_val
+
+    Args:
+        min_val: Minimum value
+        max_val: Maximum value
+
+    Returns:
+        Sampled value
+    """
+
+    return random.uniform(min_val, max_val)
+
+
+def simple_span_masking(sequence: List[int], sentinel_to_id: Dict[int, int], keep_prob: float) -> Tuple[List[int], List[int]]:
+    """Span masking for a sequence
+
+    Args:
+        sequence: Sequence to mask
+        sentinel_to_id: Mapping from sentinel to id
+        keep_prob: Probability of keeping a token
+
+    Returns:
+        Masked input sequence and masked target sequence
+    """
+    sequence_length = len(sequence)
+    # 0 for keep, 1 for mask
+    masks = torch.where(torch.rand(sequence_length) <= keep_prob, 0, 1).bool().tolist()
+
+    input_sequence = []
+    target_sequence = []
+
+    prev_mask = False
+    sentinel_count = 0
+    for token, mask in zip(sequence, masks):
+        if mask:
+            if not prev_mask:
+                sentinel_count += 1
+                input_sequence.append(sentinel_to_id[sentinel_count])
+                target_sequence.append(sentinel_to_id[sentinel_count])
+            prev_mask = True
+            target_sequence.append(token)
+        else:
+            prev_mask = False
+            input_sequence.append(token)
+
+    target_sequence.append(sentinel_to_id[sentinel_count + 1])
+    return input_sequence, target_sequence
+
+
+def chunk_span_masking(sequence_chunks: List[List[int]], sentinel_to_id: Dict[int, int], keep_prob: float) -> Tuple[List[int], List[int]]:
+    """Span masking where masking is performed at the chunk level.
+
+    Args:
+        sequence_chunks: Sequence chunks to mask
+        sentinel_to_id: Mapping from sentinel to id
+        keep_prob: Probability of keeping a token
+
+    Returns:
+        Masked input sequence and masked target sequence
+    """
+    chunk_length = len(sequence_chunks)
+    # 0 for keep, 1 for mask
+    masks = torch.where(torch.rand(chunk_length) <= keep_prob, 0, 1).bool().tolist()
+
+    input_sequence = []
+    target_sequence = []
+
+    prev_mask = False
+    sentinel_count = 0
+    for chunk, mask in zip(sequence_chunks, masks):
+        if mask:
+            if not prev_mask:
+                sentinel_count += 1
+                input_sequence.append(sentinel_to_id[sentinel_count])
+                target_sequence.append(sentinel_to_id[sentinel_count])
+            prev_mask = True
+            target_sequence.extend(chunk)
+        else:
+            prev_mask = False
+            input_sequence.extend(chunk)
+
+    target_sequence.append(sentinel_to_id[sentinel_count + 1])
+    return input_sequence, target_sequence
+
+
+
+class UnifiedMasking(object):
+    def __init__(self,
+                 modality_info: Dict,
+                 text_tokenizer: Optional[Tokenizer],
+                 input_tokens_range: Union[int, Tuple[int, int]],
+                 target_tokens_range: Optional[Union[int, Tuple[int, int]]],
+                 max_tries: int = 100,
+                 sampling_weights: Optional[List[float]] = None,):
+        """Performs masking on a dict of modalities (both image based and sequence based modalities)
+
+        Args:
+            modality_info: Dict with the modalities and their corresponding information
+            text_tokenizer: Tokenizer to use for text modalities
+            input_tokens_range: Range of number of tokens to mask in the input
+            target_tokens_range: Range of number of tokens to mask in the target
+            max_tries: Maximum number of tries to find a valid token budgets
+            sampling_weights: Sampling weights for the mixture of Dirichlet distributions
+        """
+        self.input_tokens_range = to_2tuple(input_tokens_range)
+        self.target_tokens_range = to_2tuple(target_tokens_range) if target_tokens_range is not None else None
+        self.modality_info = modality_info
+        self.num_modalities = len(modality_info)
+        self.max_tries = max_tries
+        self.min_tokens = torch.tensor([mod['min_tokens'] for mod in modality_info.values()])
+        self.max_tokens = torch.tensor([mod['max_tokens'] for mod in modality_info.values()])
+        self.mod_is_img = torch.tensor([mod['type'] == 'img' for mod in modality_info.values()])
+
+        # Dirichlet sampling (supports a mixture of multiple Dirichlet distributions)
+        eps = 1e-9
+        input_alphas = torch.tensor([mod["input_alphas"] for mod in modality_info.values()])
+        input_alphas = rearrange(input_alphas, "nmod nmix -> nmix nmod")
+        self.input_dirichlets = [Dirichlet(torch.clamp(input_alpha, min=eps)) for input_alpha in input_alphas]
+        target_alphas = torch.tensor([mod["target_alphas"] for mod in modality_info.values()])
+        target_alphas = rearrange(target_alphas, "nmod nmix -> nmix nmod")
+        self.target_dirichlets = [Dirichlet(torch.clamp(target_alpha, min=eps)) for target_alpha in target_alphas]
+        assert(len(self.input_dirichlets) == len(self.target_dirichlets))
+        self.num_dirichlets = len(self.input_dirichlets)
+        if sampling_weights is not None:
+            assert len(sampling_weights) == self.num_dirichlets
+            self.sampling_weights = torch.tensor(sampling_weights)
+        else:
+            self.sampling_weights = None
+
+        self.text_tokenizer = text_tokenizer
+        self.keep_prob_decay_factor = 0.9
+        self.sentinel_to_id = get_sentinel_to_id_mapping(text_tokenizer)
+        self.sentinel_ids = set(self.sentinel_to_id.values())
+        self.pad_id = text_tokenizer.token_to_id("[PAD]")
+        self.eos_id = text_tokenizer.token_to_id("[EOS]")
+
+    def input_token_budget(self, num_input_tokens, dir_idx=0):
+        """Sample a token budget for the input
+
+        Args:
+            num_input_tokens: Number of tokens in the input
+
+        Returns:
+            Token budget for the input
+        """
+        # Get the number of tokens for each modality
+        for i in range(self.max_tries):
+            input_token_budget = (self.input_dirichlets[dir_idx].sample() * num_input_tokens).floor().int()
+            diff = num_input_tokens - input_token_budget.sum()
+            # Adds the remaining tokens by sampling from the Dirichlet and taking the argmax
+            # This avoids adding tokens to modalities that shouldn't be sampled (i.e. with alphas ~=0)
+            input_token_budget += torch.bincount(self.input_dirichlets[dir_idx].sample_n(diff).argmax(dim=-1), minlength=len(input_token_budget))
+
+            # If token budget is over max tokens for a given modality, set it to max
+            input_token_budget = torch.clamp(input_token_budget, max=self.max_tokens)
+
+            if (input_token_budget >= self.min_tokens).all():
+                return input_token_budget.tolist()
+
+        print(f"More than max tries for input!")
+        return input_token_budget.tolist()
+
+    def target_token_budget(self, input_token_budget, num_target_tokens, dir_idx=0):
+        """Sample a token budget for the target
+
+        Args:
+            input_token_budget: Token budget for the input
+            num_target_tokens: Number of tokens in the target
+
+        Returns:
+            Token budget for the target
+        """
+        # We don't reduce the number of tokens for sequence based tasks
+        max_tokens_remaining = torch.where(self.mod_is_img, self.max_tokens - torch.tensor(input_token_budget), self.max_tokens)
+        max_tokens_remaining = torch.max(self.min_tokens, max_tokens_remaining)
+        for i in range(self.max_tries):
+            target_token_budget = (self.target_dirichlets[dir_idx].sample() * num_target_tokens).floor().int()
+            diff = num_target_tokens - target_token_budget.sum()
+            # Adds the remaining tokens by sampling from the Dirichlet and taking the argmax
+            # This avoids adding tokens to modalities that shouldn't be sampled (i.e. with alphas ~=0)
+            target_token_budget += torch.bincount(self.target_dirichlets[dir_idx].sample_n(diff).argmax(dim=-1), minlength=len(target_token_budget))
+
+            # If token budget is over max tokens for a given modality, set it to max
+            target_token_budget = torch.clamp(target_token_budget, max=max_tokens_remaining)
+
+            if (target_token_budget >= self.min_tokens).all():
+                return target_token_budget.tolist()
+
+        print(f"More than max tries for target!")
+        return target_token_budget.tolist()
+
+    def image_mask(self, tensor: torch.Tensor, num_tokens: int, input_budget: int, target_budget: int):
+        """Applies input and target masking to an image tensor
+
+        Args:
+            tensor: Image tensor
+            num_tokens: Number of tokens in the tensor
+            input_budget: Token budget for the input
+            target_budget: Token budget for the target
+
+        Returns:
+            Dictionary containing the masked image tensor, the input mask, the target mask, and the decoder attention mask
+        """
+        noise = torch.rand(num_tokens)
+        ids_shuffle = torch.argsort(noise, dim=0)
+
+        input_mask = torch.ones(num_tokens, dtype=torch.bool)
+        input_mask[:input_budget] = 0
+        input_mask = torch.gather(input_mask, dim=0, index=ids_shuffle)
+
+        if target_budget is None:
+            target_mask = ~input_mask
+        else:
+            target_mask = torch.ones(num_tokens, dtype=torch.bool)
+            target_mask[input_budget:input_budget + target_budget] = 0
+            target_mask = torch.gather(target_mask, dim=0, index=ids_shuffle)
+
+        decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)
+        first_mask_token = torch.argmin(target_mask + torch.arange(target_mask.shape[0], device=target_mask.device) * 1e-6)
+        decoder_attention_mask[first_mask_token] = (~target_mask).sum() # Equiv. to target budget
+
+        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
+
+    def sequence_token_mask(self, sequence_ids: str, max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str, vocab_offset: int):
+        """Applies input and target masking to a sequence of tokens (e.g. DINOv2 global tokens)
+        The keep probability is sampled from a cosine schedule and does not depend on the number of tokens in the sequence.
+        If the keep probability results in a sequence that is too long, then it is lowered until the sequence is short enough.
+
+        Args:
+            sequence_ids: Sequence ids
+            max_tokens: Maximum number of tokens in the sequence
+            input_budget: Token budget for the input
+            target_budget: Token budget for the target
+            keep_scheme: Scheme for sampling the keep probability
+            vocab_offset: Offset to avoid overlap with sentinel tokens
+
+        Returns:
+            Dictionary containing the masked sequence tensor, the input mask, the target mask, and the decoder attention mask
+        """
+        seq_ids = sequence_ids
+        seq_ids = seq_ids + vocab_offset # Avoid overlap with sentinel tokens (needs to be substracted after decoding)
+
+        # If input budget is 0, treat it as if the whole sequence is completely masked
+        if input_budget == 0:
+            keep_prob = 0.
+            input_seq_ids = []
+            _, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
+        else:
+            if keep_scheme == 'random':
+                keep_prob = sample_uniform(0, 1)
+            elif keep_scheme == 'all':
+                keep_prob = 1.0
+            elif keep_scheme == 'binary':
+                keep_prob = random.choice([0., 1.])
+            else:
+                raise ValueError(f"Invalid keep scheme for sequence masking: {keep_scheme}")
+
+            input_seq_ids, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
+            # Keep lowering the keep_prob while we are over-budget
+            while len(input_seq_ids) > input_budget:
+                keep_prob = keep_prob * self.keep_prob_decay_factor
+                input_seq_ids, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
+
+        # Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
+        max_length = (max_tokens + 1) * 2
+        tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
+        input_mask = torch.ones(max_length, dtype=torch.bool)
+        target_mask = torch.ones(max_length, dtype=torch.bool)
+        decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
+
+        # Set input and input mask
+        tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
+        input_mask[:len(input_seq_ids)] = 0
+
+        if target_budget is None or len(target_seq_ids) <= target_budget:
+            tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
+            target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
+            decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1
+        else:
+            # Randomly choose sentinel token.
+            sentinel_indices = [i for i, token_id in enumerate(target_seq_ids) if token_id in self.sentinel_ids]
+            # If there is more than 1 sentinel, avoid sampling the very last one which indicates the end of the sequence
+            chosen_sentinel = np.random.randint(max(1, len(sentinel_indices) - 1))
+            # If length starting at this token g.t. budget, truncate until budget is reached
+            if len(target_seq_ids) - sentinel_indices[chosen_sentinel] >= target_budget:
+                target_seq_ids = target_seq_ids[sentinel_indices[chosen_sentinel]:sentinel_indices[chosen_sentinel] + target_budget]
+            # Otherwise, select earliest sentinel token such that we don't go over budget
+            # Note: We could also use the randomly chosen sentinel token, but that would waste budget
+            else:
+                for idx in sentinel_indices:
+                    if len(target_seq_ids) - idx <= target_budget:
+                        target_seq_ids = target_seq_ids[idx:]
+                        break
+
+            tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
+            target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
+            decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1
+
+        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
+
+    def sequence_mask(self, sequence: Union[str, List[str]], max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str):
+        """Applies input and target masking to a sequence
+
+        The keep probability is sampled from a cosine schedule and does not depend on the number of tokens in the sequence.
+        If the keep probability results in a sequence that is too long, then it is lowered until the sequence is short enough.
+
+        Args:
+            sequence: Sequence, can be either a str or list of strings
+            max_tokens: Maximum number of tokens in the sequence
+            input_budget: Token budget for the input
+            target_budget: Token budget for the target
+            keep_scheme: Scheme for sampling the keep probability
+
+        Returns:
+            Dictionary containing the masked sequence tensor, the input mask, the target mask, and the decoder attention mask
+        """
+        if isinstance(sequence, str):
+            # Tokenize the sequence and get the ids
+            seq_ids: List[int] = self.text_tokenizer.encode(sequence).ids
+            # Add EOS to all sequences
+            seq_ids.append(self.eos_id)
+            # Truncate sequence
+            seq_ids = seq_ids[:max_tokens]
+
+            # Use default span masking
+            span_masking_fn = simple_span_masking
+
+        elif isinstance(sequence, list):
+            # Tokenize the sequence chunks and get the ids
+            encoded_seq_chunks = self.text_tokenizer.encode_batch(sequence)
+            seq_ids: List[List[int]] = [seq.ids for seq in encoded_seq_chunks]
+            # Add EOS as an extra chunk
+            seq_ids.append([self.eos_id])
+            # Truncate sequence to keep all chunks below max token length
+            cumulative_token_count = np.cumsum(np.array([len(chunk) for chunk in seq_ids]))
+            seq_ids = [chunk for (chunk, token_count) in zip(seq_ids, cumulative_token_count) if token_count <= max_tokens]
+
+            # Span mask over chunks
+            span_masking_fn = chunk_span_masking
+
+        else:
+            raise ValueError(f"Invalid sequence: {sequence}")
+
+
+        # If input budget is 0, treat it as if the whole sequence is completely masked
+        if input_budget == 0:
+            keep_prob = 0.
+            input_seq_ids = []
+            _, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, keep_prob)
+        else:
+            if keep_scheme == 'random':
+                keep_prob = sample_uniform(0, 1)
+            elif keep_scheme == 'all':
+                keep_prob = 1.0
+            elif keep_scheme == 'binary':
+                keep_prob = random.choice([0., 1.])
+            else:
+                raise ValueError(f"Invalid keep scheme for sequence masking: {keep_scheme}")
+
+            input_seq_ids, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, keep_prob)
+            # Keep lowering the keep_prob while we are over-budget
+            while len(input_seq_ids) > input_budget:
+                keep_prob = keep_prob * self.keep_prob_decay_factor
+                input_seq_ids, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, keep_prob)
+
+        # Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
+        max_length = (max_tokens + 1) * 2
+        tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
+        input_mask = torch.ones(max_length, dtype=torch.bool)
+        target_mask = torch.ones(max_length, dtype=torch.bool)
+        decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
+
+        # Set input and input mask
+        tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
+        input_mask[:len(input_seq_ids)] = 0
+
+        if target_budget is None or len(target_seq_ids) <= target_budget:
+            tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
+            target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
+            decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1
+        else:
+            # Randomly choose sentinel token.
+            sentinel_indices = [i for i, token_id in enumerate(target_seq_ids) if token_id in self.sentinel_ids]
+            # If there is more than 1 sentinel, avoid sampling the very last one which indicates the end of the sequence
+            chosen_sentinel = np.random.randint(max(1, len(sentinel_indices) - 1))
+            # If length starting at this token g.t. budget, truncate until budget is reached
+            if len(target_seq_ids) - sentinel_indices[chosen_sentinel] >= target_budget:
+                target_seq_ids = target_seq_ids[sentinel_indices[chosen_sentinel]:sentinel_indices[chosen_sentinel] + target_budget]
+            # Otherwise, select earliest sentinel token such that we don't go over budget
+            # Note: We could also use the randomly chosen sentinel token, but that would waste budget
+            else:
+                for idx in sentinel_indices:
+                    if len(target_seq_ids) - idx <= target_budget:
+                        target_seq_ids = target_seq_ids[idx:]
+                        break
+
+            tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
+            target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
+            decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1
+
+        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
+
+
+    def sequence_emb_mask_span(self, emb_tensor: torch.Tensor, max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str):
+        """Applies input masking to an sequence embedding tensor, target masking is not supported with sequence embeddings
+
+        Args:
+            emb_tensor: Sequence embedding tensor
+            max_tokens: Maximum number of tokens in the sequence
+            input_budget: Token budget for the input
+            target_budget: Token budget for the target (unused for now)
+            keep_scheme: Scheme for sampling the keep probability
+
+        Returns:
+            Dictionary containing the masked sequence embedding tensor, the input mask, the target mask, and the decoder attention mask
+        """
+        # Only supported as input modality now
+
+        # Make fake seq ids for sequence embeddings to reuse simple_span_masking function
+        fake_seq_ids = []
+        emb_dict = {}
+        id_num = len(self.sentinel_ids)
+        emb_ind = 0
+        while(len(fake_seq_ids) < len(emb_tensor)):
+            if id_num not in self.sentinel_ids: # replace with T5 sentinel_id
+                fake_seq_ids.append(id_num)
+                emb_dict[id_num] = emb_tensor[emb_ind, :]
+                emb_ind += 1
+            id_num += 1
+
+        # Truncate sequence
+        fake_seq_ids = fake_seq_ids[:max_tokens]
+
+        # If input budget is 0, treat it as if the whole sequence is completely masked
+        if input_budget == 0:
+            keep_prob = 0.
+            fake_input_seq_ids = []
+            _, fake_target_seq_ids = simple_span_masking(fake_seq_ids, self.sentinel_to_id, keep_prob)
+        else:
+            if keep_scheme == 'random':
+                keep_prob = sample_uniform(0, 1)
+            elif keep_scheme == 'all':
+                keep_prob = 1.0
+            elif keep_scheme == 'binary':
+                keep_prob = random.choice([0., 1.])
+            else:
+                raise ValueError(f"Invalid keep scheme for sequence masking: {keep_scheme}")
+
+            fake_input_seq_ids, fake_target_seq_ids = simple_span_masking(fake_seq_ids, self.sentinel_to_id, keep_prob)
+            # Keep lowering the keep_prob while we are over-budget
+            while len(fake_input_seq_ids) > input_budget:
+                keep_prob = keep_prob * self.keep_prob_decay_factor
+                fake_input_seq_ids, fake_target_seq_ids = simple_span_masking(fake_seq_ids, self.sentinel_to_id, keep_prob)
+
+        # Span masking can add up to max_tokens tokens for input
+        max_length = max_tokens
+        tensor = torch.zeros((max_length, emb_tensor.shape[1]), dtype=torch.float32)
+        input_mask = torch.ones(max_length, dtype=torch.bool)
+        target_mask = torch.ones(max_length, dtype=torch.bool)
+        decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
+
+        # Put tensor values back based on the fake seq ids
+        for i_, fake_id in enumerate(fake_input_seq_ids):
+            if fake_id in self.sentinel_ids:
+                tensor[i_, :] = torch.zeros_like(emb_tensor[0,:]) # TODO replace to learned embeddings later
+            else:
+                tensor[i_, :] = emb_dict[fake_id]
+
+        # Set input and input mask
+        input_mask[:len(fake_input_seq_ids)] = 0
+
+        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
+
+
+    def __call__(self, mod_dict):
+        """Applies input and target masking to a dictionary of modalities
+
+        Args:
+            mod_dict: Dictionary of modalities
+
+        Returns:
+            Dictionary containing the masked modalities
+        """
+        if self.sampling_weights is not None:
+            # Sample masking scheme according to a list of weights
+            dir_idx = torch.multinomial(self.sampling_weights, 1).item()
+        else:
+            # Randomly sample masking scheme
+            dir_idx = random.randint(0, self.num_dirichlets - 1)
+
+        num_input_tokens = random.randint(*self.input_tokens_range)
+        num_target_tokens = random.randint(*self.target_tokens_range) if self.target_tokens_range is not None else None
+
+        input_token_budget = self.input_token_budget(num_input_tokens, dir_idx)
+
+        if num_target_tokens is not None:
+            target_token_budget = self.target_token_budget(input_token_budget, num_target_tokens, dir_idx)
+        else:
+            target_token_budget = [None] * self.num_modalities
+
+        masked_mod_dict = {}
+        for (mod_name, mod_info), input_budget, target_budget in zip(self.modality_info.items(), input_token_budget, target_token_budget):
+            mod_type = mod_info['type']
+            mod_name_load = mod_name if mod_name in mod_dict else get_transform_key(mod_name)
+            if mod_type == 'img':
+                masked_mod_dict[mod_name] = self.image_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget)
+            elif mod_type == 'seq':
+                keep_scheme = 'random' if ('keep' not in mod_info) else mod_info['keep'][dir_idx]
+                masked_mod_dict[mod_name] = self.sequence_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme)
+            elif mod_type == 'seq_token':
+                keep_scheme = 'random' if ('keep' not in mod_info) else mod_info['keep'][dir_idx]
+                vocab_offset = mod_info.get('vocab_offset', 0) # Check if any space is allocated to sentinel tokens and other special tokens
+                masked_mod_dict[mod_name] = self.sequence_token_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme, vocab_offset=vocab_offset)
+            elif mod_type == "seq_emb":
+                keep_scheme = 'random' if ('keep' not in mod_info) else mod_info['keep'][dir_idx]
+                masked_mod_dict[mod_name] = self.sequence_emb_mask_span(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme)
+            else:
+                raise ValueError(f"Invalid modality type: {mod_type}")
+
+        return masked_mod_dict
+
+
+class TransferMasking(object):
+    def __init__(self,
+                 modality_info: Dict,
+                 text_tokenizer: Optional[Tokenizer],
+                 input_modalities: List[str],
+                 target_modalities: List[str]):
+        """Performs masking for transfer on a dict of modalities (both image based and sequence based modalities),
+        by specifying which modalities are inputs and which are targets.
+
+        Args:
+            modality_info: Dict with the modalities and their corresponding information
+            text_tokenizer: Tokenizer to use for text modalities
+            input_modalities: List of modalities to use as input
+            target_modalities: List of modalities to use as target
+        """
+        self.modality_info = modality_info
+        self.num_modalities = len(modality_info)
+        self.min_tokens = torch.tensor([mod['min_tokens'] for mod in modality_info.values()])
+        self.max_tokens = torch.tensor([mod['max_tokens'] for mod in modality_info.values()])
+        self.mod_is_img = torch.tensor([mod['type'] == 'img' for mod in modality_info.values()])
+
+        self.input_modalities = set(input_modalities)
+        self.target_modalities = set(target_modalities)
+
+        # Tokenizer for text modalities
+        self.text_tokenizer = text_tokenizer
+        if self.text_tokenizer is not None:
+            self.keep_prob_decay_factor = 0.9
+            self.sentinel_to_id = get_sentinel_to_id_mapping(text_tokenizer)
+            self.sentinel_ids = set(self.sentinel_to_id.values())
+            self.pad_id = text_tokenizer.token_to_id("[PAD]")
+            self.eos_id = text_tokenizer.token_to_id("[EOS]")
+
+
+    def input_image(self, tensor: torch.Tensor, num_tokens: int):
+        """Applies masking for an image given as input
+
+        Args:
+            tensor: Image tensor
+            num_tokens: Number of tokens in the tensor
+
+        Returns:
+            Dictionary containing the masked image tensor, the input mask, the target mask, and the decoder attention mask
+        """
+
+        # Input mask
+        input_mask = torch.zeros(num_tokens, dtype=torch.bool)
+        # Target mask
+        target_mask = torch.ones(num_tokens, dtype=torch.bool)
+        # Decoder attention mask
+        decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)
+
+        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
+
+    def target_image(self, tensor: torch.Tensor, num_tokens: int):
+        """Applies masking for an image given as target
+
+        Args:
+            tensor: Image tensor
+            num_tokens: Number of tokens in the tensor
+
+        Returns:
+            Dictionary containing the masked image tensor, the input mask, the target mask, and the decoder attention mask
+        """
+
+        # Input mask
+        input_mask = torch.ones(num_tokens, dtype=torch.bool)
+        # Target mask
+        target_mask = torch.zeros(num_tokens, dtype=torch.bool)
+        # Decoder attention mask
+        decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)
+        decoder_attention_mask[0] = num_tokens
+
+        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
+
+
+    def input_sequence(self, sequence_str: str, max_tokens: int):
+        """Applies masking for a sequence given as input
+
+        Args:
+            sequence_str: Sequence string
+            max_tokens: Maximum number of tokens in the sequence
+
+        Returns:
+            Dictionary containing the masked sequence string, the input mask, the target mask, and the decoder attention mask
+        """
+        # Tokenize the text and get the ids
+        seq_ids = self.text_tokenizer.encode(sequence_str).ids
+        # Add EOS to all sequences
+        seq_ids.append(self.eos_id)
+        # Truncate sequence
+        seq_ids = seq_ids[:max_tokens]
+
+        keep_prob = 1.
+        input_seq_ids, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
+
+        # Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
+        max_length = (max_tokens + 1) * 2
+        tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
+        input_mask = torch.ones(max_length, dtype=torch.bool)
+        target_mask = torch.ones(max_length, dtype=torch.bool)
+        decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
+
+        # Set input and input mask
+        tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
+        input_mask[:len(input_seq_ids)] = 0
+
+        tensor[max_tokens:max_tokens + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
+        target_mask[max_tokens:max_tokens + len(target_seq_ids)] = 0
+        decoder_attention_mask[max_tokens:max_tokens + len(target_seq_ids)] = 1
+
+
+        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
+
+
+    def target_sequence(self, sequence_str: str, max_tokens: int):
+        """Applies masking for a sequence given as target
+
+        Args:
+            sequence_str: Sequence string
+            max_tokens: Maximum number of tokens in the sequence
+
+        Returns:
+            Dictionary containing the masked sequence string, the input mask, the target mask, and the decoder attention mask
+        """
+        # Tokenize the text and get the ids
+        seq_ids = self.text_tokenizer.encode(sequence_str).ids
+        # Add EOS to all sequences
+        seq_ids.append(self.eos_id)
+        # Truncate sequence
+        seq_ids = seq_ids[:max_tokens]
+
+        keep_prob = 0.
+        input_seq_ids = []
+        _, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
+
+        # Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
+        max_length = (max_tokens + 1) * 2
+        tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
+        input_mask = torch.ones(max_length, dtype=torch.bool)
+        target_mask = torch.ones(max_length, dtype=torch.bool)
+        decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
+
+        # Set input and input mask
+        tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
+        input_mask[:len(input_seq_ids)] = 0
+
+        tensor[max_tokens:max_tokens + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
+        target_mask[max_tokens:max_tokens + len(target_seq_ids)] = 0
+        decoder_attention_mask[max_tokens:max_tokens + len(target_seq_ids)] = 1
+
+        return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask,
+                "decoder_attention_mask": decoder_attention_mask}
+
+    def __call__(self, mod_dict):
+        """Applies input and target masking to a dictionary of modalities
+
+        Args:
+            mod_dict: Dictionary of modalities
+
+        Returns:
+            Dictionary containing the masked modalities
+        """
+        masked_mod_dict = {}
+        for mod_name, mod_info in self.modality_info.items():
+            mod_type = mod_info['type']
+            if mod_type == 'img' and mod_name in self.input_modalities:
+                masked_mod_dict[mod_name] = self.input_image(mod_dict[mod_name], mod_info['max_tokens'])
+            elif mod_type == 'img' and mod_name in self.target_modalities:
+                masked_mod_dict[mod_name] = self.target_image(mod_dict[mod_name], mod_info['max_tokens'])
+            elif mod_type == 'seq' and mod_name in self.input_modalities:
+                masked_mod_dict[mod_name] = self.input_sequence(mod_dict[mod_name], mod_info['max_tokens'])
+            elif mod_type == 'seq' and mod_name in self.target_modalities:
+                masked_mod_dict[mod_name] = self.target_sequence(mod_dict[mod_name], mod_info['max_tokens'])
+            else:
+                raise ValueError(f"Invalid modality type: {mod_type} or modality name not in input or target modalities: {mod_name}")
+
+        if 'mask_valid' in mod_dict:
+            masked_mod_dict['mask_valid'] = mod_dict['mask_valid']
+
+        return masked_mod_dict
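An illustrative sketch of the span-masking primitive defined above, using a small hand-made sentinel mapping instead of the tokenizer-derived one (an assumption for the example). With a keep probability between 0 and 1 the output is random, so no specific result is asserted.

from fourm.data.masking import simple_span_masking

# Hypothetical token ids and sentinel mapping; in the training pipeline these
# come from the text tokenizer via get_sentinel_to_id_mapping.
sequence = [101, 102, 103, 104, 105]
sentinel_to_id = {i: 1000 + i for i in range(1, 8)}

input_seq, target_seq = simple_span_masking(sequence, sentinel_to_id, keep_prob=0.5)
print(input_seq)   # kept tokens, with one sentinel id standing in for each masked span
print(target_seq)  # each sentinel id followed by its masked tokens, plus a closing sentinel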
fourm/data/modality_info.py
ADDED
@@ -0,0 +1,427 @@
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from functools import partial
|
15 |
+
|
16 |
+
import fourm.utils.data_constants as data_constants
|
17 |
+
from fourm.data.modality_transforms import (CaptionTransform, DepthTransform,
|
18 |
+
DetectionTransform, MaskTransform,
|
19 |
+
NormalTransform, RGBTransform,
|
20 |
+
SemsegTransform, TokTransform,
|
21 |
+
CaptionEmbTransform, MetadataTransform,
|
22 |
+
HumanPoseTransform, ColorPaletteTransform,
|
23 |
+
SAMInstanceTokTransform, SAMInstanceTransform)
|
24 |
+
from fourm.models.decoder_embeddings import (ImageTokenDecoderEmbedding,
|
25 |
+
SequenceDecoderEmbedding)
|
26 |
+
from fourm.models.encoder_embeddings import (ImageEncoderEmbedding,
|
27 |
+
ImageTokenEncoderEmbedding,
|
28 |
+
SequenceEncoderEmbedding,
|
29 |
+
SequenceEmbEncoderEmbedding)
|
30 |
+
from fourm.utils import generate_uint15_hash
|
31 |
+
|
32 |
+
MODALITY_INFO = {
|
33 |
+
# 4M-7 modalities
|
34 |
+
'rgb@224': {
|
35 |
+
'input_size': 224,
|
36 |
+
'patch_size': 16,
|
37 |
+
'encoder_embedding': partial(ImageEncoderEmbedding, num_channels=3),
|
38 |
+
'decoder_embedding': None,
|
39 |
+
'min_tokens': 0,
|
40 |
+
'max_tokens': None, # Will be set to 196
|
41 |
+
'type': 'img',
|
42 |
+
'num_channels': 3,
|
43 |
+
'id': generate_uint15_hash('rgb@224'),
|
44 |
+
'path': 'rgb',
|
45 |
+
},
|
46 |
+
'rgb': { # used for tokenizer training
|
47 |
+
'type': 'img',
|
48 |
+
'num_channels': 3,
|
49 |
+
'id': generate_uint15_hash('rgb'),
|
50 |
+
'path': 'rgb',
|
51 |
+
},
|
52 |
+
'caption': {
|
53 |
+
'vocab_size': 30_000,
|
54 |
+
'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0),
|
55 |
+
'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0),
|
56 |
+
'min_tokens': 0,
|
57 |
+
'max_tokens': 256,
|
58 |
+
'type': 'seq',
|
59 |
+
'id': generate_uint15_hash('caption'),
|
60 |
+
},
|
61 |
+
'det': {
|
62 |
+
'vocab_size': 30_000,
|
63 |
+
'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0),
|
64 |
+
'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0),
|
65 |
+
'min_tokens': 0,
|
66 |
+
'max_tokens': 256,
|
67 |
+
'type': 'seq',
|
68 |
+
'id': generate_uint15_hash('det'),
|
69 |
+
},
|
70 |
+
'tok_rgb@224': {
|
71 |
+
'input_size': 224,
|
72 |
+
'patch_size': 16,
|
73 |
+
'vocab_size': 16384,
|
74 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=16384),
|
75 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=16384),
|
76 |
+
'min_tokens': 0,
|
77 |
+
'max_tokens': None, # Will be set to 196
|
78 |
+
'type': 'img',
|
79 |
+
'id': generate_uint15_hash('tok_rgb@224'),
|
80 |
+
'pretokenized': True,
|
81 |
+
},
|
82 |
+
'tok_depth@224': {
|
83 |
+
'input_size': 224,
|
84 |
+
'patch_size': 16,
|
85 |
+
'vocab_size': 8192,
|
86 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
|
87 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
|
88 |
+
'min_tokens': 0,
|
89 |
+
'max_tokens': None, # Will be set to 196
|
90 |
+
'type': 'img',
|
91 |
+
'id': generate_uint15_hash('tok_depth@224'),
|
92 |
+
'pretokenized': True,
|
93 |
+
},
|
94 |
+
'depth': { # used for tokenizer training
|
95 |
+
'type': 'img',
|
96 |
+
'num_channels': 1,
|
97 |
+
'id': generate_uint15_hash('depth'),
|
98 |
+
},
|
99 |
+
'tok_normal@224': {
|
100 |
+
'input_size': 224,
|
101 |
+
'patch_size': 16,
|
102 |
+
'vocab_size': 8192,
|
103 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
|
104 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
|
105 |
+
'min_tokens': 0,
|
106 |
+
'max_tokens': None, # Will be set to 196
|
107 |
+
'type': 'img',
|
108 |
+
'id': generate_uint15_hash('tok_normal@224'),
|
109 |
+
'pretokenized': True,
|
110 |
+
},
|
111 |
+
'normal': { # used for tokenizer training
|
112 |
+
'type': 'img',
|
113 |
+
'num_channels': 3,
|
114 |
+
'id': generate_uint15_hash('normal'),
|
115 |
+
},
|
116 |
+
'tok_semseg@224': {
|
117 |
+
'input_size': 224,
|
118 |
+
'patch_size': 16,
|
119 |
+
'vocab_size': 4096,
|
120 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=4096),
|
121 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=4096),
|
122 |
+
'min_tokens': 0,
|
123 |
+
'max_tokens': None, # Will be set to 196
|
124 |
+
'type': 'img',
|
125 |
+
'id': generate_uint15_hash('tok_semseg@224'),
|
126 |
+
'pretokenized': True,
|
127 |
+
},
|
128 |
+
'semseg_coco': { # used for tokenizer training
|
129 |
+
'type': 'img',
|
130 |
+
'num_channels': 64,
|
131 |
+
'num_labels': data_constants.COCO_SEMSEG_NUM_CLASSES,
|
132 |
+
'id': generate_uint15_hash('semseg_coco'),
|
133 |
+
},
|
134 |
+
'tok_clip@224': {
|
135 |
+
'input_size': 224,
|
136 |
+
'patch_size': 16,
|
137 |
+
'vocab_size': 8192,
|
138 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
|
139 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
|
140 |
+
'min_tokens': 0,
|
141 |
+
'max_tokens': None, # Will be set to 196
|
142 |
+
'type': 'img',
|
143 |
+
'id': generate_uint15_hash('tok_clip@224'),
|
144 |
+
'pretokenized': True,
|
145 |
+
},
|
146 |
+
'CLIP-B16': { # used for tokenizer training
|
147 |
+
'type': 'feature_map',
|
148 |
+
'num_channels': 512,
|
149 |
+
'id': generate_uint15_hash('CLIP-B16'),
|
150 |
+
},
|
151 |
+
|
152 |
+
# 4M-21 modalities
|
153 |
+
't5_caption': {
|
154 |
+
'encoder_embedding': partial(SequenceEmbEncoderEmbedding, max_length=77, padding_idx=0),
|
155 |
+
'decoder_embedding': None,
|
156 |
+
'min_tokens': 0,
|
157 |
+
'max_tokens': 77,
|
158 |
+
'type': 'seq_emb',
|
159 |
+
'id': generate_uint15_hash('t5_caption'),
|
160 |
+
},
|
161 |
+
'metadata': {
|
162 |
+
'vocab_size': 30_000,
|
163 |
+
'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=40, padding_idx=0, sincos_pos_emb=True),
|
164 |
+
'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=40, padding_idx=0, sincos_pos_emb=True),
|
165 |
+
'min_tokens': 0,
|
166 |
+
'max_tokens': 40, # At most 2x19=38 for 19 metadata types, +1 for EOS, +1 for sentinel
|
167 |
+
'type': 'seq',
|
168 |
+
'id': generate_uint15_hash('metadata'),
|
169 |
+
'shared_vocab': ['caption'],
|
170 |
+
'path': 'metadata',
|
171 |
+
},
|
172 |
+
'human_poses': {
|
173 |
+
'vocab_size': 30_000,
|
174 |
+
'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=263, padding_idx=0, sincos_pos_emb=True),
|
175 |
+
'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=263, padding_idx=0, sincos_pos_emb=True),
|
176 |
+
'min_tokens': 0,
|
177 |
+
'max_tokens': 275, # 7*39 pose tokens, +1 for EOS, +1 for sentinel
|
178 |
+
'type': 'seq',
|
179 |
+
'num_channels': 207, # for tokenization training, only the pose part is needed
|
180 |
+
'id': generate_uint15_hash('human_poses'),
|
181 |
+
'shared_vocab': ['caption'],
|
182 |
+
},
|
183 |
+
'color_palette': {
|
184 |
+
'vocab_size': 30_000,
|
185 |
+
'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=23, padding_idx=0, sincos_pos_emb=True),
|
186 |
+
'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=23, padding_idx=0, sincos_pos_emb=True),
|
187 |
+
'min_tokens': 0,
|
188 |
+
'max_tokens': 23, # 7x3=21 for 7 colors, +1 for EOS, +1 for sentinel
|
189 |
+
'type': 'seq',
|
190 |
+
'id': generate_uint15_hash('color_palette'),
|
191 |
+
'shared_vocab': ['caption'],
|
192 |
+
'path': 'color_palette',
|
193 |
+
},
|
194 |
+
'sam_mask': {
|
195 |
+
'encoder_embedding': None,
|
196 |
+
'decoder_embedding': None,
|
197 |
+
'min_tokens': 0,
|
198 |
+
'max_tokens': 64,
|
199 |
+
'type': 'img',
|
200 |
+
'num_channels': 1,
|
201 |
+
'id': generate_uint15_hash('sam_mask'),
|
202 |
+
},
|
203 |
+
'sam_instance': {
|
204 |
+
'vocab_size': 30_000,
|
205 |
+
'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=290, padding_idx=0, sincos_pos_emb=True),
|
206 |
+
'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=290, padding_idx=0, sincos_pos_emb=True),
|
207 |
+
'min_tokens': 0,
|
208 |
+
'max_tokens': 290,
|
209 |
+
'type': 'seq',
|
210 |
+
'id': generate_uint15_hash('sam_instance'),
|
211 |
+
'shared_vocab': ['caption'],
|
212 |
+
'pretokenized': True,
|
213 |
+
},
|
214 |
+
'tok_canny_edge@224': {
|
215 |
+
'input_size': 224,
|
216 |
+
'patch_size': 16,
|
217 |
+
'vocab_size': 8192,
|
218 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
|
219 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
|
220 |
+
'min_tokens': 0,
|
221 |
+
'max_tokens': None, # Will be set to 196
|
222 |
+
'type': 'img',
|
223 |
+
'id': generate_uint15_hash('tok_canny_edge@224'),
|
224 |
+
'pretokenized': True,
|
225 |
+
},
|
226 |
+
'canny_edge': { # used for tokenizer training
|
227 |
+
'type': 'img',
|
228 |
+
'num_channels': 1,
|
229 |
+
'id': generate_uint15_hash('canny_edge'),
|
230 |
+
},
|
231 |
+
'tok_sam_edge@224': {
|
232 |
+
'input_size': 224,
|
233 |
+
'patch_size': 16,
|
234 |
+
'vocab_size': 8192,
|
235 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
|
236 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
|
237 |
+
'min_tokens': 0,
|
238 |
+
'max_tokens': None, # Will be set to 196
|
239 |
+
'type': 'img',
|
240 |
+
'id': generate_uint15_hash('tok_sam_edge@224'),
|
241 |
+
'pretokenized': True,
|
242 |
+
},
|
243 |
+
'tok_dinov2@224': {
|
244 |
+
'input_size': 224,
|
245 |
+
'patch_size': 14,
|
246 |
+
'vocab_size': 8192,
|
247 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
|
248 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
|
249 |
+
'min_tokens': 0,
|
250 |
+
'max_tokens': None, # Will be set to 256
|
251 |
+
'type': 'img',
|
252 |
+
'id': generate_uint15_hash('tok_dinov2@224'),
|
253 |
+
'pretokenized': True,
|
254 |
+
},
|
255 |
+
'DINOv2-B14': { # used for tokenizer training
|
256 |
+
'type': 'feature_map',
|
257 |
+
'num_channels': 768,
|
258 |
+
'id': generate_uint15_hash('DINOv2-B14'),
|
259 |
+
},
|
260 |
+
'tok_imagebind@224': {
|
261 |
+
'input_size': 224,
|
262 |
+
'patch_size': 14,
|
263 |
+
'vocab_size': 8192,
|
264 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
|
265 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
|
266 |
+
'min_tokens': 0,
|
267 |
+
'max_tokens': None, # Will be set to 256
|
268 |
+
'type': 'img',
|
269 |
+
'id': generate_uint15_hash('tok_imagebind@224'),
|
270 |
+
'pretokenized': True,
|
271 |
+
},
|
272 |
+
'ImageBind-H14': { # used for tokenizer training
|
273 |
+
'type': 'feature_map',
|
274 |
+
'num_channels': 1280,
|
275 |
+
'id': generate_uint15_hash('ImageBind-H14'),
|
276 |
+
},
|
277 |
+
'tok_dinov2_global': {
|
278 |
+
'vocab_size': 8192,
|
279 |
+
'patch_size': 56,
|
280 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192, sincos_pos_emb=False),
|
281 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192, sincos_pos_emb=False),
|
282 |
+
'min_tokens': 0,
|
283 |
+
'max_tokens': 16,
|
284 |
+
'type': 'img',
|
285 |
+
'id': generate_uint15_hash('tok_dinov2_global'),
|
286 |
+
'pretokenized': True,
|
287 |
+
},
|
288 |
+
'DINOv2-B14-global': { # used for tokenizer training
|
289 |
+
'type': 'feature_map',
|
290 |
+
'num_channels': 768,
|
291 |
+
'id': generate_uint15_hash('DINOv2-B14-global'),
|
292 |
+
},
|
293 |
+
'tok_imagebind_global': {
|
294 |
+
'vocab_size': 8192,
|
295 |
+
'patch_size': 56,
|
296 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192, sincos_pos_emb=False),
|
297 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192, sincos_pos_emb=False),
|
298 |
+
'min_tokens': 0,
|
299 |
+
'max_tokens': 16,
|
300 |
+
'type': 'img',
|
301 |
+
'id': generate_uint15_hash('tok_imagebind_global'),
|
302 |
+
'pretokenized': True,
|
303 |
+
},
|
304 |
+
'ImageBind-H14-global': { # used for tokenizer training
|
305 |
+
'type': 'feature_map',
|
306 |
+
'num_channels': 1280,
|
307 |
+
'id': generate_uint15_hash('ImageBind-H14-global'),
|
308 |
+
},
|
309 |
+
|
310 |
+
### 224->448 super resolution modalities
|
311 |
+
'rgb@448': {
|
312 |
+
'input_size': 448,
|
313 |
+
'patch_size': 16,
|
314 |
+
'encoder_embedding': partial(ImageEncoderEmbedding, num_channels=3),
|
315 |
+
'decoder_embedding': None,
|
316 |
+
'min_tokens': 0,
|
317 |
+
'max_tokens': None, # Will be set to 784
|
318 |
+
'type': 'img',
|
319 |
+
'num_channels': 3,
|
320 |
+
'id': generate_uint15_hash('rgb@448'),
|
321 |
+
'path': 'rgb',
|
322 |
+
},
|
323 |
+
'tok_rgb@448': {
|
324 |
+
'input_size': 448,
|
325 |
+
'patch_size': 16,
|
326 |
+
'vocab_size': 16384,
|
327 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=16384),
|
328 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=16384),
|
329 |
+
'min_tokens': 0,
|
330 |
+
'max_tokens': None, # Will be set to 784
|
331 |
+
'type': 'img',
|
332 |
+
'id': generate_uint15_hash('tok_rgb@448'),
|
333 |
+
'pretokenized': True,
|
334 |
+
},
|
335 |
+
'tok_depth@448': {
|
336 |
+
'input_size': 448,
|
337 |
+
'patch_size': 16,
|
338 |
+
'vocab_size': 8192,
|
339 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
|
340 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
|
341 |
+
'min_tokens': 0,
|
342 |
+
'max_tokens': None, # Will be set to 784
|
343 |
+
'type': 'img',
|
344 |
+
'id': generate_uint15_hash('tok_depth@448'),
|
345 |
+
'pretokenized': True,
|
346 |
+
},
|
347 |
+
'tok_normal@448': {
|
348 |
+
'input_size': 448,
|
349 |
+
'patch_size': 16,
|
350 |
+
'vocab_size': 8192,
|
351 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
|
352 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
|
353 |
+
'min_tokens': 0,
|
354 |
+
'max_tokens': None, # Will be set to 784
|
355 |
+
'type': 'img',
|
356 |
+
'id': generate_uint15_hash('tok_normal@448'),
|
357 |
+
'pretokenized': True,
|
358 |
+
},
|
359 |
+
'tok_semseg@448': {
|
360 |
+
'input_size': 448,
|
361 |
+
'patch_size': 16,
|
362 |
+
'vocab_size': 4096,
|
363 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=4096),
|
364 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=4096),
|
365 |
+
'min_tokens': 0,
|
366 |
+
'max_tokens': None, # Will be set to 784
|
367 |
+
'type': 'img',
|
368 |
+
'id': generate_uint15_hash('tok_semseg@448'),
|
369 |
+
'pretokenized': True,
|
370 |
+
},
|
371 |
+
'tok_clip@448': {
|
372 |
+
'input_size': 448,
|
373 |
+
'patch_size': 16,
|
374 |
+
'vocab_size': 8192,
|
375 |
+
'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
|
376 |
+
'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
|
377 |
+
'min_tokens': 0,
|
378 |
+
'max_tokens': None, # Will be set to 784
|
379 |
+
'type': 'img',
|
380 |
+
'id': generate_uint15_hash('tok_clip@448'),
|
381 |
+
'pretokenized': True,
|
382 |
+
},
|
383 |
+
}
|
384 |
+
|
385 |
+
# Note: @res suffix is ignored for modality transforms
|
386 |
+
MODALITY_TRANSFORMS = {
|
387 |
+
# 4M-7 modalities
|
388 |
+
'rgb': RGBTransform(imagenet_default_mean_and_std=True),
|
389 |
+
'caption': CaptionTransform(aligned_captions=True),
|
390 |
+
'det': DetectionTransform(det_threshold=0.6, det_max_instances=None, bbox_order='dist_to_orig', coord_bins=1000, min_visibility=0.0),
|
391 |
+
'tok_rgb': TokTransform(),
|
392 |
+
'tok_depth': TokTransform(),
|
393 |
+
'tok_normal': TokTransform(),
|
394 |
+
'tok_semseg': TokTransform(),
|
395 |
+
'tok_clip': TokTransform(),
|
396 |
+
# 4M-21 modalities
|
397 |
+
't5_caption': CaptionEmbTransform(),
|
398 |
+
'metadata': MetadataTransform(special_vmin=0, special_vmax=999, shuffle=True, random_trunc=False, return_chunks=True),
|
399 |
+
'human_poses': HumanPoseTransform(coord_bins=1000),
|
400 |
+
'color_palette': ColorPaletteTransform(coord_bins=1000),
|
401 |
+
'sam_instance': SAMInstanceTokTransform(image_size=224, points_per_side=7, point_order='random'),
|
402 |
+
'tok_canny_edge': TokTransform(),
|
403 |
+
'tok_sam_edge': TokTransform(),
|
404 |
+
'tok_dinov2': TokTransform(),
|
405 |
+
'tok_imagebind': TokTransform(),
|
406 |
+
'tok_dinov2_global': TokTransform(),
|
407 |
+
'tok_imagebind_global': TokTransform(),
|
408 |
+
# Other
|
409 |
+
'mask_valid': MaskTransform(mask_pool_size=1),
|
410 |
+
}
|
411 |
+
|
412 |
+
MODALITY_TRANSFORMS_DIVAE = {
|
413 |
+
'rgb': RGBTransform(imagenet_default_mean_and_std=False),
|
414 |
+
'depth': DepthTransform(standardize_depth=True),
|
415 |
+
'normal': NormalTransform(standardize_surface_normals=False),
|
416 |
+
'mask_valid': MaskTransform(mask_pool_size=1),
|
417 |
+
'semseg_coco': SemsegTransform(shift_idx_by_one=True),
|
418 |
+
'canny_edge': RGBTransform(imagenet_default_mean_and_std=False),
|
419 |
+
'human_poses': HumanPoseTransform(coord_bins=1000, only_pose=True),
|
420 |
+
'sam_mask': SAMInstanceTransform(mask_size=64, max_instance_n=1),
|
421 |
+
}
|
422 |
+
|
423 |
+
MODALITY_TRANSFORMS_VQCONTROLNET = {
|
424 |
+
'rgb': RGBTransform(imagenet_default_mean_and_std=False),
|
425 |
+
'mask_valid': MaskTransform(mask_pool_size=1),
|
426 |
+
'caption': CaptionTransform(aligned_captions=True),
|
427 |
+
}
|
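Before the transform definitions that follow, a quick orientation note on the table above (an illustrative sketch with assumed usage, not part of the committed files): the entries with max_tokens set to None are placeholders for the patch count, which follows directly from input_size and patch_size.
from fourm.data.modality_info import MODALITY_INFO

info = MODALITY_INFO['tok_rgb@224']
num_patches = (info['input_size'] // info['patch_size']) ** 2
print(num_patches)  # 196, matching the "Will be set to 196" comments above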
fourm/data/modality_transforms.py
ADDED
@@ -0,0 +1,1387 @@
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import gzip
|
15 |
+
import json
|
16 |
+
import random
|
17 |
+
from pathlib import Path
|
18 |
+
from typing import Optional, Tuple, List, Dict
|
19 |
+
from abc import ABC, abstractmethod
|
20 |
+
|
21 |
+
from PIL import Image
|
22 |
+
import cv2
|
23 |
+
|
24 |
+
import albumentations as A
|
25 |
+
import numpy as np
|
26 |
+
import torch
|
27 |
+
import torchvision.transforms.functional as TF
|
28 |
+
import torchvision.transforms as T
|
29 |
+
from einops import rearrange, repeat, reduce
|
30 |
+
|
31 |
+
from fourm.utils import to_2tuple
|
32 |
+
from fourm.utils.data_constants import (IMAGENET_DEFAULT_MEAN,
|
33 |
+
IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN,
|
34 |
+
IMAGENET_SURFACE_NORMAL_STD, IMAGENET_SURFACE_NORMAL_MEAN,
|
35 |
+
IMAGENET_INCEPTION_STD, SEG_IGNORE_INDEX, PAD_MASK_VALUE)
|
36 |
+
|
37 |
+
|
38 |
+
# The @-symbol is used to specify the resolution of a modality. Syntax: modality@resolution
|
39 |
+
def get_transform_key(mod_name):
|
40 |
+
return mod_name.split('@')[0]
|
41 |
+
|
42 |
+
def get_transform_resolution(mod_name, default_resolution, to_tuple=True):
|
43 |
+
res = int(mod_name.split('@')[1]) if '@' in mod_name else default_resolution
|
44 |
+
return to_2tuple(res) if to_tuple else res
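# Illustrative examples of the '@' resolution syntax (editor's addition, assumed values):
#   get_transform_key('tok_rgb@448')                          -> 'tok_rgb'
#   get_transform_resolution('tok_rgb@448', 224)               -> (448, 448)
#   get_transform_resolution('caption', 224, to_tuple=False)   -> 224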
|
45 |
+
|
46 |
+
def get_transform(mod_name, transforms_dict):
|
47 |
+
return transforms_dict.get(get_transform_key(mod_name), IdentityTransform())
|
48 |
+
|
49 |
+
def get_pil_resample_mode(resample_mode: str):
|
50 |
+
"""
|
51 |
+
Returns the PIL resampling mode for the given resample mode string.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
resample_mode: Resampling mode string
|
55 |
+
"""
|
56 |
+
if resample_mode is None:
|
57 |
+
return None
|
58 |
+
elif resample_mode == "bilinear":
|
59 |
+
return Image.Resampling.BILINEAR if hasattr(Image, 'Resampling') else Image.BILINEAR
|
60 |
+
elif resample_mode == "bicubic":
|
61 |
+
return Image.Resampling.BICUBIC if hasattr(Image, 'Resampling') else Image.BICUBIC
|
62 |
+
elif resample_mode == "nearest":
|
63 |
+
return Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST
|
64 |
+
else:
|
65 |
+
raise ValueError(f"Resample mode {resample_mode} is not supported.")
|
66 |
+
|
67 |
+
class UnifiedDataTransform(object):
|
68 |
+
def __init__(self, transforms_dict, image_augmenter, resample_mode: str = None, add_sizes: bool = False, **kwargs):
|
69 |
+
"""Unified data augmentation for FourM
|
70 |
+
|
71 |
+
Args:
|
72 |
+
transforms_dict (dict): Dict of transforms for each modality
|
73 |
+
image_augmenter (AbstractImageAugmenter): Image augmenter
|
74 |
+
resample_mode (str, optional): Resampling mode for PIL images (default: None -> uses default resampling mode for data type)
|
75 |
+
One out of ["bilinear", "bicubic", "nearest", None].
|
76 |
+
add_sizes (bool, optional): Whether to add crop coordinates and original size to the output dict
|
77 |
+
"""
|
78 |
+
|
79 |
+
self.transforms_dict = transforms_dict
|
80 |
+
self.image_augmenter = image_augmenter
|
81 |
+
self.resample_mode = resample_mode
|
82 |
+
self.add_sizes = add_sizes
|
83 |
+
|
84 |
+
def unified_image_augment(self, mod_dict, crop_settings):
|
85 |
+
"""Apply the image augmenter to all modalities where it is applicable
|
86 |
+
|
87 |
+
Args:
|
88 |
+
mod_dict (dict): Dict of modalities
|
89 |
+
crop_settings (dict): Crop settings
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
dict: Transformed dict of modalities
|
93 |
+
"""
|
94 |
+
|
95 |
+
crop_coords, flip, orig_size, target_size, rand_aug_idx = self.image_augmenter(mod_dict, crop_settings)
|
96 |
+
|
97 |
+
mod_dict = {
|
98 |
+
k: self.transforms_dict[get_transform_key(k)].image_augment(
|
99 |
+
v, crop_coords=crop_coords, flip=flip, orig_size=orig_size,
|
100 |
+
target_size=get_transform_resolution(k, target_size), rand_aug_idx=rand_aug_idx,
|
101 |
+
resample_mode=self.resample_mode
|
102 |
+
)
|
103 |
+
for k, v in mod_dict.items()
|
104 |
+
}
|
105 |
+
|
106 |
+
if self.add_sizes:
|
107 |
+
mod_dict["crop_coords"] = torch.tensor(crop_coords)
|
108 |
+
mod_dict["orig_size"] = torch.tensor(orig_size)
|
109 |
+
|
110 |
+
return mod_dict
|
111 |
+
|
112 |
+
def __call__(self, mod_dict):
|
113 |
+
"""Apply the augmentation to a dict of modalities (both image based and sequence based modalities)
|
114 |
+
|
115 |
+
Args:
|
116 |
+
mod_dict (dict): Dict of modalities
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
dict: Transformed dict of modalities
|
120 |
+
"""
|
121 |
+
crop_settings = mod_dict.pop("crop_settings", None)
|
122 |
+
|
123 |
+
mod_dict = {k: get_transform(k, self.transforms_dict).preprocess(v) for k, v in mod_dict.items()}
|
124 |
+
|
125 |
+
mod_dict = self.unified_image_augment(mod_dict, crop_settings)
|
126 |
+
|
127 |
+
mod_dict = {k: get_transform(k, self.transforms_dict).postprocess(v) for k, v in mod_dict.items()}
|
128 |
+
|
129 |
+
return mod_dict
|
130 |
+
|
131 |
+
def __repr__(self):
|
132 |
+
repr = "(UnifiedDataAugmentation,\n"
|
133 |
+
repr += ")"
|
134 |
+
return repr
|
135 |
+
|
136 |
+
|
137 |
+
class AbstractTransform(ABC):
|
138 |
+
|
139 |
+
@abstractmethod
|
140 |
+
def load(self, sample):
|
141 |
+
pass
|
142 |
+
|
143 |
+
@abstractmethod
|
144 |
+
def preprocess(self, sample):
|
145 |
+
pass
|
146 |
+
|
147 |
+
@abstractmethod
|
148 |
+
def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
|
149 |
+
rand_aug_idx: Optional[int], resample_mode: str = None):
|
150 |
+
pass
|
151 |
+
|
152 |
+
@abstractmethod
|
153 |
+
def postprocess(self, v):
|
154 |
+
pass
|
155 |
+
|
156 |
+
|
157 |
+
class ImageTransform(AbstractTransform):
|
158 |
+
|
159 |
+
@staticmethod
|
160 |
+
def pil_loader(path: str) -> Image.Image:
|
161 |
+
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
|
162 |
+
# with open(path, 'rb') as f:
|
163 |
+
# img = Image.open(f)
|
164 |
+
img = Image.open(path)
|
165 |
+
return img
|
166 |
+
|
167 |
+
|
168 |
+
@staticmethod
|
169 |
+
def image_hflip(img: Image, flip: bool):
|
170 |
+
"""Crop and resize an image
|
171 |
+
|
172 |
+
:param img: Image to flip
|
173 |
+
:param flip: Whether to flip the image
|
174 |
+
:return: Flipped image (if flip = True)
|
175 |
+
"""
|
176 |
+
if flip:
|
177 |
+
img = TF.hflip(img)
|
178 |
+
return img
|
179 |
+
|
180 |
+
@staticmethod
|
181 |
+
def image_crop_and_resize(img: Image, crop_coords: Tuple, target_size: Tuple, resample_mode: str = None):
|
182 |
+
"""Crop and resize an image
|
183 |
+
|
184 |
+
:param img: Image to crop and resize
|
185 |
+
:param crop_coords: Coordinates of the crop (top, left, h, w)
|
186 |
+
:param target_size: Coordinates of the resize (height, width)
|
187 |
+
:return: Cropped and resized image
|
188 |
+
"""
|
189 |
+
|
190 |
+
top, left, h, w = crop_coords
|
191 |
+
resize_height, resize_width = target_size
|
192 |
+
img = TF.crop(img, top, left, h, w)
|
193 |
+
resample_mode = get_pil_resample_mode(resample_mode)
|
194 |
+
img = img.resize((resize_height, resize_width), resample=resample_mode)
|
195 |
+
return img
|
196 |
+
|
197 |
+
|
198 |
+
class RGBTransform(ImageTransform):
|
199 |
+
|
200 |
+
def __init__(self, imagenet_default_mean_and_std=True, color_jitter=False, color_jitter_strength=0.5):
|
201 |
+
self.rgb_mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
|
202 |
+
self.rgb_std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
|
203 |
+
self.color_jitter = color_jitter
|
204 |
+
self.color_jitter_transform = self.random_color_jitter(color_jitter_strength)
|
205 |
+
|
206 |
+
def random_color_jitter(self, strength=0.5):
|
207 |
+
# Color Jitter from Pix2Seq and SimCLR
|
208 |
+
# Source: https://github.com/google-research/pix2seq/blob/main/data/data_utils.py#L114
|
209 |
+
t = T.Compose([
|
210 |
+
T.RandomApply([T.ColorJitter(brightness=0.8 * strength, contrast=0.8 * strength, saturation=0.8 * strength, hue=0.2 * strength)], p=0.8),
|
211 |
+
T.RandomApply([T.Grayscale(num_output_channels=3)], p=0.2),
|
212 |
+
])
|
213 |
+
|
214 |
+
return t
|
215 |
+
|
216 |
+
def rgb_to_tensor(self, img):
|
217 |
+
img = TF.to_tensor(img)
|
218 |
+
img = TF.normalize(img, mean=self.rgb_mean, std=self.rgb_std)
|
219 |
+
return img
|
220 |
+
|
221 |
+
def load(self, path):
|
222 |
+
# TODO: Instead of converting to RGB here, do it either in the preprocess or the postprocess step. Makes it compatible with wds dataloading.
|
223 |
+
sample = self.pil_loader(path)
|
224 |
+
return sample
|
225 |
+
|
226 |
+
def preprocess(self, sample):
|
227 |
+
sample = sample.convert('RGB')
|
228 |
+
|
229 |
+
if self.color_jitter:
|
230 |
+
sample = self.color_jitter_transform(sample)
|
231 |
+
|
232 |
+
return sample
|
233 |
+
|
234 |
+
def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
|
235 |
+
rand_aug_idx: Optional[int], resample_mode: str = None):
|
236 |
+
img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode=resample_mode)
|
237 |
+
img = self.image_hflip(img, flip)
|
238 |
+
return img
|
239 |
+
|
240 |
+
def postprocess(self, sample):
|
241 |
+
sample = self.rgb_to_tensor(sample)
|
242 |
+
return sample
|
243 |
+
|
244 |
+
|
245 |
+
class DepthTransform(ImageTransform):
|
246 |
+
|
247 |
+
def __init__(self, standardize_depth=True):
|
248 |
+
self.standardize_depth = standardize_depth
|
249 |
+
|
250 |
+
def depth_to_tensor(self, img):
|
251 |
+
img = torch.Tensor(img / (2 ** 16 - 1.0))
|
252 |
+
img = img.unsqueeze(0) # 1 x H x W
|
253 |
+
if self.standardize_depth:
|
254 |
+
img = self.truncated_depth_standardization(img)
|
255 |
+
return img
|
256 |
+
|
257 |
+
@staticmethod
|
258 |
+
def truncated_depth_standardization(depth, thresh: float = 0.1):
|
259 |
+
"""Truncated depth standardization
|
260 |
+
|
261 |
+
:param depth: Depth map
|
262 |
+
:param thresh: Threshold
|
263 |
+
:return: Robustly standardized depth map
|
264 |
+
"""
|
265 |
+
# Flatten depth and remove bottom and top 10% of values
|
266 |
+
trunc_depth = torch.sort(depth.reshape(-1), dim=0)[0]
|
267 |
+
trunc_depth = trunc_depth[int(thresh * trunc_depth.shape[0]): int((1 - thresh) * trunc_depth.shape[0])]
|
268 |
+
return (depth - trunc_depth.mean()) / torch.sqrt(trunc_depth.var() + 1e-6)
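# Illustrative note (editor's addition): with thresh=0.1 and a flattened depth map of
# 100 values, only the sorted slice trunc_depth[10:90] contributes to the mean/std,
# so saturated near/far outliers do not dominate the standardization.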
|
269 |
+
|
270 |
+
def load(self, path):
|
271 |
+
sample = self.pil_loader(path)
|
272 |
+
return sample
|
273 |
+
|
274 |
+
def preprocess(self, sample):
|
275 |
+
return sample
|
276 |
+
|
277 |
+
def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
|
278 |
+
rand_aug_idx: Optional[int], resample_mode: str = None):
|
279 |
+
img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode=resample_mode)
|
280 |
+
img = self.image_hflip(img, flip)
|
281 |
+
return img
|
282 |
+
|
283 |
+
def postprocess(self, sample):
|
284 |
+
sample = np.array(sample)
|
285 |
+
sample = self.depth_to_tensor(sample)
|
286 |
+
return sample
|
287 |
+
|
288 |
+
|
289 |
+
class NormalTransform(ImageTransform):
|
290 |
+
|
291 |
+
def __init__(self, standardize_surface_normals=False):
|
292 |
+
self.normal_mean = (0.5, 0.5, 0.5) if not standardize_surface_normals else IMAGENET_SURFACE_NORMAL_MEAN
|
293 |
+
self.normal_std = (0.5, 0.5, 0.5) if not standardize_surface_normals else IMAGENET_SURFACE_NORMAL_STD
|
294 |
+
|
295 |
+
def normal_to_tensor(self, img):
|
296 |
+
img = TF.to_tensor(img)
|
297 |
+
img = TF.normalize(img, mean=self.normal_mean, std=self.normal_std)
|
298 |
+
return img
|
299 |
+
|
300 |
+
def load(self, path):
|
301 |
+
sample = self.pil_loader(path)
|
302 |
+
return sample
|
303 |
+
|
304 |
+
def preprocess(self, sample):
|
305 |
+
return sample
|
306 |
+
|
307 |
+
def image_hflip(self, img: Image, flip: bool):
|
308 |
+
if flip:
|
309 |
+
img = TF.hflip(img)
|
310 |
+
flipped_np = np.array(img)
|
311 |
+
flipped_np[:, :, 0] = 255 - flipped_np[:, :, 0]
|
312 |
+
img = Image.fromarray(flipped_np)
|
313 |
+
|
314 |
+
return img
|
315 |
+
|
316 |
+
def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
|
317 |
+
rand_aug_idx: Optional[int], resample_mode: str = None):
|
318 |
+
img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode=resample_mode)
|
319 |
+
img = self.image_hflip(img, flip)
|
320 |
+
return img
|
321 |
+
|
322 |
+
def postprocess(self, sample):
|
323 |
+
sample = self.normal_to_tensor(sample)
|
324 |
+
return sample
|
325 |
+
|
326 |
+
|
327 |
+
class SemsegTransform(ImageTransform):
|
328 |
+
|
329 |
+
def __init__(self, scale_factor=1.0, shift_idx_by_one=False, id_mapping: Optional[Dict] = None, select_channel=None):
|
330 |
+
self.scale_factor = scale_factor
|
331 |
+
self.shift_idx_by_one = shift_idx_by_one
|
332 |
+
self.id_mapping = id_mapping
|
333 |
+
self.select_channel = select_channel
|
334 |
+
|
335 |
+
def map_semseg_values(self, sample):
|
336 |
+
sample = np.asarray(sample)
|
337 |
+
mapping_fn = lambda x: self.id_mapping.get(x, x)
|
338 |
+
sample = np.vectorize(mapping_fn)(sample)
|
339 |
+
sample = Image.fromarray(sample, mode='P')
|
340 |
+
return sample
|
341 |
+
|
342 |
+
def semseg_to_tensor(self, img):
|
343 |
+
# Rescale to scale factor
|
344 |
+
if self.scale_factor != 1.0:
|
345 |
+
target_height, target_width = int(img.height * self.scale_factor), int(img.width * self.scale_factor)
|
346 |
+
img = img.resize((target_width, target_height))
|
347 |
+
# Using pil_to_tensor keeps it in uint8, to_tensor converts it to float (rescaled to [0, 1])
|
348 |
+
img = TF.pil_to_tensor(img).to(torch.long).squeeze(0)
|
349 |
+
# 255->0, 254->0, all else shifted up by one
|
350 |
+
return img
|
351 |
+
|
352 |
+
def load(self, path):
|
353 |
+
sample = self.pil_loader(path)
|
354 |
+
if self.select_channel is not None:
|
355 |
+
sample = sample.split()[self.select_channel]
|
356 |
+
return sample
|
357 |
+
|
358 |
+
def preprocess(self, sample):
|
359 |
+
sample = sample.convert('P')
|
360 |
+
|
361 |
+
if self.id_mapping is not None:
|
362 |
+
sample = self.map_semseg_values(sample)
|
363 |
+
|
364 |
+
if self.shift_idx_by_one:
|
365 |
+
sample = np.asarray(sample)
|
366 |
+
sample = sample + 1
|
367 |
+
sample = Image.fromarray(sample, mode='P')
|
368 |
+
|
369 |
+
return sample
|
370 |
+
|
371 |
+
def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
|
372 |
+
rand_aug_idx: Optional[int], resample_mode: str = None):
|
373 |
+
# Value for padding with TF.crop is always 0.
|
374 |
+
# Override resampling mode to 'nearest' for semseg
|
375 |
+
img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode='nearest')
|
376 |
+
img = self.image_hflip(img, flip)
|
377 |
+
return img
|
378 |
+
|
379 |
+
def postprocess(self, sample):
|
380 |
+
img = self.semseg_to_tensor(sample)
|
381 |
+
return img
|
382 |
+
|
383 |
+
|
384 |
+
class SAMInstanceTransform(AbstractTransform):
|
385 |
+
|
386 |
+
def __init__(self, mask_size=64, max_instance_n=20, bbox_area_threshold=0.0005):
|
387 |
+
self.mask_size = mask_size
|
388 |
+
self.max_instance_n = max_instance_n
|
389 |
+
self.bbox_area_threshold = bbox_area_threshold
|
390 |
+
|
391 |
+
def get_bbox(self, instance):
|
392 |
+
""" Gets bounding box of the given instance
|
393 |
+
"""
|
394 |
+
min_h, max_h = instance[:,:,1].min(), instance[:,:,1].max()
|
395 |
+
min_w, max_w = instance[:,:,0].min(), instance[:,:,0].max()
|
396 |
+
return [min_h, min_w, max_h, max_w]
|
397 |
+
|
398 |
+
def extend_instance_points(self, instance, border_fn):
|
399 |
+
""" Given an instance and a border function `border_fn`, extends the instance points with crossing points between the instance and
|
400 |
+
the crop borders. The crossing points are obtained using border_fn.
|
401 |
+
"""
|
402 |
+
p = instance[:,0]
|
403 |
+
p_next = np.roll(p, (-1), axis=(0))
|
404 |
+
final_points = []
|
405 |
+
for x, xn in zip(p, p_next):
|
406 |
+
final_points.append(x)
|
407 |
+
for r in border_fn(x, xn):
|
408 |
+
final_points.append(r.astype(np.int32))
|
409 |
+
p = np.stack(final_points)
|
410 |
+
return p[:,None]
|
411 |
+
|
412 |
+
def remove_redundant_lines(self, orig_instance, instance):
|
413 |
+
""" Removes the redundant lines added during cropping.
|
414 |
+
"""
|
415 |
+
final_points = []
|
416 |
+
for p in instance:
|
417 |
+
distance = cv2.pointPolygonTest(orig_instance, (p[0,0].item(), p[0,1].item()), measureDist=True)
|
418 |
+
if distance >= 0:
|
419 |
+
final_points.append(p[0])
|
420 |
+
return np.stack(final_points)[:,None]
|
421 |
+
|
422 |
+
def get_border_functions(self, crop_points):
|
423 |
+
""" Creates and returns a function `fn` using crop region coordinates given in crop_points.
|
424 |
+
`fn` receives two input points x and xn and returns all the crossing points between the line connecting
|
425 |
+
x and xn, and the borders of the cropping rectangle.
|
426 |
+
"""
|
427 |
+
p = crop_points[:,0]
|
428 |
+
p_next = np.roll(p, (-1), axis=(0))
|
429 |
+
def fn(x, xn):
|
430 |
+
output = []
|
431 |
+
c_diff = p_next - p
|
432 |
+
x_diff = x - xn
|
433 |
+
for diff, c in zip(c_diff, p):
|
434 |
+
A = np.array([
|
435 |
+
[diff[0], x_diff[0]],
|
436 |
+
[diff[1], x_diff[1]]
|
437 |
+
])
|
438 |
+
b = x - c
|
439 |
+
try:
|
440 |
+
lmbda = np.linalg.solve(A, b)
|
441 |
+
if 0 <= lmbda[0] <= 1 and 0 <= lmbda[1] <= 1:
|
442 |
+
output.append(lmbda[1] * xn + (1-lmbda[1]) * x)
|
443 |
+
except np.linalg.LinAlgError:  # singular system: the two segments are parallel
|
444 |
+
continue
|
445 |
+
return output
|
446 |
+
return fn
|
447 |
+
|
448 |
+
def crop_sample(self, sample, crop_coords):
|
449 |
+
""" Crop the sample using crop coordinates.
|
450 |
+
"""
|
451 |
+
top, left, h, w = crop_coords
|
452 |
+
crop_region = (left, top, left + w, top + h)
|
453 |
+
crop_points = np.array([
|
454 |
+
[crop_region[0], crop_region[1]],
|
455 |
+
[crop_region[2], crop_region[1]],
|
456 |
+
[crop_region[2], crop_region[3]],
|
457 |
+
[crop_region[0], crop_region[3]],
|
458 |
+
])[:,None]
|
459 |
+
border_functions = self.get_border_functions(crop_points)
|
460 |
+
cropped_sample = []
|
461 |
+
for instance in sample:
|
462 |
+
instance = self.extend_instance_points(instance, border_functions)
|
463 |
+
filter_condition = (
|
464 |
+
(instance[:, :, 0] > crop_region[0]) &
|
465 |
+
(instance[:, :, 0] < crop_region[2]) &
|
466 |
+
(instance[:, :, 1] > crop_region[1]) &
|
467 |
+
(instance[:, :, 1] < crop_region[3])
|
468 |
+
)
|
469 |
+
if not np.any(filter_condition):
|
470 |
+
continue
|
471 |
+
|
472 |
+
instance_copy = instance.copy()
|
473 |
+
instance_copy[:, :, 0] = np.clip(instance[:, :, 0], a_min=crop_region[0], a_max=crop_region[2])
|
474 |
+
instance_copy[:, :, 1] = np.clip(instance[:, :, 1], a_min=crop_region[1], a_max=crop_region[3])
|
475 |
+
instance_copy = self.remove_redundant_lines(instance, instance_copy)
|
476 |
+
instance_copy[:, :, 0] -= crop_region[0]
|
477 |
+
instance_copy[:, :, 1] -= crop_region[1]
|
478 |
+
|
479 |
+
cropped_sample.append(instance_copy)
|
480 |
+
return cropped_sample
|
481 |
+
|
482 |
+
def resize_sample(self, sample, original_size, target_size):
|
483 |
+
""" Resize the sample
|
484 |
+
"""
|
485 |
+
width_scale = target_size[1] / original_size[1]
|
486 |
+
height_scale = target_size[0] / original_size[0]
|
487 |
+
resized_sample = []
|
488 |
+
for instance in sample:
|
489 |
+
instance_copy = instance.copy()
|
490 |
+
instance_copy[:, :, 0] = np.round(width_scale * instance_copy[:, :, 0])
|
491 |
+
instance_copy[:, :, 1] = np.round(height_scale * instance_copy[:, :, 1])
|
492 |
+
resized_sample.append(instance_copy)
|
493 |
+
return resized_sample
|
494 |
+
|
495 |
+
def remove_tiny_instances(self, sample, image_size):
|
496 |
+
""" Remove instances that have an area ratio smaller than `bbox_area_threshold`.
|
497 |
+
"""
|
498 |
+
filtered_sample = []
|
499 |
+
for instance in sample:
|
500 |
+
min_h, min_w, max_h, max_w = self.get_bbox(instance)
|
501 |
+
bbox_area_ratio = (max_h - min_h) * (max_w - min_w) / (image_size[0] * image_size[1])
|
502 |
+
if bbox_area_ratio < self.bbox_area_threshold:
|
503 |
+
continue
|
504 |
+
filtered_sample.append(instance)
|
505 |
+
return filtered_sample
|
506 |
+
|
507 |
+
def hflip(self, sample, width):
|
508 |
+
""" Horizontal flipping the instances in a sample.
|
509 |
+
"""
|
510 |
+
flipped_sample = []
|
511 |
+
for instance in sample:
|
512 |
+
instance_copy = instance.copy()
|
513 |
+
instance_copy[:, :, 0] = width - instance_copy[:, :, 0]
|
514 |
+
flipped_sample.append(instance_copy)
|
515 |
+
return flipped_sample
|
516 |
+
|
517 |
+
def get_binary_masks(self, sample):
|
518 |
+
""" Creates the binary mask of each instance in the sample.
|
519 |
+
"""
|
520 |
+
if self.max_instance_n is None:
|
521 |
+
max_instance_n = len(sample)
|
522 |
+
else:
|
523 |
+
max_instance_n = self.max_instance_n
|
524 |
+
masks = np.zeros((max_instance_n, self.mask_size, self.mask_size))
|
525 |
+
bboxes = np.zeros((max_instance_n, 4))
|
526 |
+
valid = np.full(max_instance_n, False)
|
527 |
+
for i, instance in enumerate(sample):
|
528 |
+
bbox = self.get_bbox(instance)
|
529 |
+
min_h, min_w, max_h, max_w = bbox
|
530 |
+
instance_copy = instance.copy()
|
531 |
+
mask = np.zeros((self.mask_size, self.mask_size), dtype=np.uint8)
|
532 |
+
instance_copy[:,:,0] = (instance_copy[:,:,0] - min_w) / (max_w - min_w) * self.mask_size
|
533 |
+
instance_copy[:,:,1] = (instance_copy[:,:,1] - min_h) / (max_h - min_h) * self.mask_size
|
534 |
+
cv2.drawContours(mask, [instance_copy], 0, (255), thickness=cv2.FILLED)
|
535 |
+
masks[i] = mask / 255.0
|
536 |
+
bboxes[i] = np.array(bbox)
|
537 |
+
valid[i] = True
|
538 |
+
return masks, bboxes, valid
|
539 |
+
|
540 |
+
def load(self, path):
|
541 |
+
sample = np.load(path, allow_pickle=True)
|
542 |
+
return sample
|
543 |
+
|
544 |
+
def preprocess(self, sample):
|
545 |
+
if self.max_instance_n is None or len(sample) <= self.max_instance_n:
|
546 |
+
indices = np.arange(len(sample))
|
547 |
+
else:
|
548 |
+
indices = np.random.choice(len(sample), size=self.max_instance_n, replace=False)
|
549 |
+
return [p['points'] for i, p in enumerate(sample) if i in indices]
|
550 |
+
|
551 |
+
def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
|
552 |
+
rand_aug_idx: Optional[int], resample_mode: str = None):
|
553 |
+
v = self.crop_sample(v, crop_coords)
|
554 |
+
_, _, h, w = crop_coords
|
555 |
+
v = self.resize_sample(v, (h, w), target_size)
|
556 |
+
v = self.remove_tiny_instances(v, target_size)
|
557 |
+
if flip:
|
558 |
+
v = self.hflip(v, target_size[1])  # target_size is (height, width); hflip needs the width
|
559 |
+
return v
|
560 |
+
|
561 |
+
def postprocess(self, sample):
|
562 |
+
sample, bboxes, valid = self.get_binary_masks(sample)
|
563 |
+
return {
|
564 |
+
'instance': torch.from_numpy(sample).to(torch.float32),
|
565 |
+
'bbox': torch.from_numpy(bboxes).to(torch.float32),
|
566 |
+
'valid': torch.from_numpy(valid)
|
567 |
+
}
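# Shape note (editor's addition, derived from the defaults above): with mask_size=64
# and max_instance_n=20, the returned dict contains
#   'instance': float32 tensor of shape (20, 64, 64)  - per-instance binary masks
#   'bbox':     float32 tensor of shape (20, 4)        - [min_h, min_w, max_h, max_w]
#   'valid':    bool tensor of shape (20,)              - True for filled instance slots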
|
568 |
+
|
569 |
+
|
570 |
+
class MaskTransform(ImageTransform):
|
571 |
+
|
572 |
+
def __init__(self, mask_pool_size=1):
|
573 |
+
assert isinstance(mask_pool_size, int)
|
574 |
+
self.mask_pool_size = mask_pool_size # Use to expand masks
|
575 |
+
|
576 |
+
def mask_to_tensor(self, img):
|
577 |
+
mask = TF.to_tensor(img)
|
578 |
+
if self.mask_pool_size > 1:
|
579 |
+
mask = reduce(mask, 'c (h1 h2) (w1 w2) -> c h1 w1', 'min', h2=self.mask_pool_size, w2=self.mask_pool_size)
|
580 |
+
mask = repeat(mask, 'c h1 w1 -> c (h1 h2) (w1 w2)', h2=self.mask_pool_size, w2=self.mask_pool_size)
|
581 |
+
return (mask == 1.0)
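# Illustrative note (editor's addition): with mask_pool_size=2 the min-pool + repeat
# above invalidates any 2x2 block that contains a single invalid pixel, e.g.
# [[1, 1], [1, 0]] -> min 0 -> broadcast back to [[0, 0], [0, 0]] -> all False.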
|
582 |
+
|
583 |
+
def load(self, path):
|
584 |
+
sample = self.pil_loader(path)
|
585 |
+
return sample
|
586 |
+
|
587 |
+
def preprocess(self, sample):
|
588 |
+
return sample
|
589 |
+
|
590 |
+
def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
|
591 |
+
rand_aug_idx: Optional[int], resample_mode: str = None):
|
592 |
+
# Override resampling mode to 'nearest' for masks
|
593 |
+
img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode='nearest')
|
594 |
+
img = self.image_hflip(img, flip)
|
595 |
+
return img
|
596 |
+
|
597 |
+
def postprocess(self, sample):
|
598 |
+
sample = self.mask_to_tensor(sample)
|
599 |
+
return sample
|
600 |
+
|
601 |
+
|
602 |
+
class TokTransform(AbstractTransform):
|
603 |
+
|
604 |
+
def __init__(self):
|
605 |
+
pass
|
606 |
+
|
607 |
+
def load(self, path):
|
608 |
+
sample = np.load(path).astype(int)
|
609 |
+
return sample
|
610 |
+
|
611 |
+
def preprocess(self, sample):
|
612 |
+
return sample
|
613 |
+
|
614 |
+
def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
|
615 |
+
rand_aug_idx: Optional[int], resample_mode: str = None):
|
616 |
+
if rand_aug_idx is None:
|
617 |
+
raise ValueError("Crop settings / augmentation index are missing but a pre-tokenized modality is being used")
|
618 |
+
v = torch.tensor(v[rand_aug_idx])
|
619 |
+
return v
|
620 |
+
|
621 |
+
def postprocess(self, sample):
|
622 |
+
return sample
|
623 |
+
|
624 |
+
|
625 |
+
class DetectionTransform(AbstractTransform):
|
626 |
+
|
627 |
+
def __init__(self, det_threshold=0.6, det_max_instances=None, bbox_order='dist_to_orig', coord_bins=1000, min_visibility=0.0, return_raw=False):
|
628 |
+
self.det_threshold = det_threshold
|
629 |
+
self.det_max_instances = det_max_instances
|
630 |
+
self.coord_bins = coord_bins
|
631 |
+
self.min_visibility = min_visibility
|
632 |
+
self.return_raw = return_raw
|
633 |
+
|
634 |
+
if bbox_order == 'area':
|
635 |
+
self.bbox_order = self.order_bboxes_by_area
|
636 |
+
elif bbox_order == 'score':
|
637 |
+
self.bbox_order = self.order_bboxes_by_score
|
638 |
+
elif bbox_order == 'random':
|
639 |
+
self.bbox_order = self.shuffle_bboxes
|
640 |
+
else:
|
641 |
+
self.bbox_order = self.order_bboxes_by_dist_to_orig
|
642 |
+
|
643 |
+
@staticmethod
|
644 |
+
def order_bboxes_by_area(bboxes):
|
645 |
+
return sorted(bboxes, key=lambda x: (x[2] - x[0]) * (x[3] - x[1]), reverse=True)
|
646 |
+
|
647 |
+
@staticmethod
|
648 |
+
def order_bboxes_by_dist_to_orig(bboxes):
|
649 |
+
return sorted(bboxes, key=lambda x: x[0] ** 2 + x[1] ** 2)
|
650 |
+
|
651 |
+
@staticmethod
|
652 |
+
def order_bboxes_by_score(bboxes):
|
653 |
+
return sorted(bboxes, key=lambda x: x[5], reverse=True)
|
654 |
+
|
655 |
+
@staticmethod
|
656 |
+
def shuffle_bboxes(bboxes):
|
657 |
+
return sorted(bboxes, key=lambda x: random.random())
|
658 |
+
|
659 |
+
def convert_detection_instance(self, instances):
|
660 |
+
"""Convert instances dict to list of lists where each list takes the form:
|
661 |
+
[xmin, ymin, xmax, ymax, class_name, score]
|
662 |
+
"""
|
663 |
+
|
664 |
+
instances = [inst['boxes'] + [inst['class_name'], inst['score']] for inst in instances if inst['score'] >= self.det_threshold]
|
665 |
+
return instances
|
666 |
+
|
667 |
+
def bboxes_hflip(self, bboxes: List[Tuple], image_size: Tuple, flip: bool):
|
668 |
+
image_height, image_width = image_size
|
669 |
+
if flip:
|
670 |
+
bboxes = [tuple(A.bbox_hflip(bbox[:4], rows=image_height, cols=image_width)) + tuple(bbox[4:])
|
671 |
+
for bbox in bboxes]
|
672 |
+
|
673 |
+
return bboxes
|
674 |
+
|
675 |
+
def bboxes_crop_and_resize(self, bboxes: List[Tuple], crop_coords: Tuple, orig_size: Tuple):
|
676 |
+
"""Crop and resize bounding boxes
|
677 |
+
|
678 |
+
Args:
|
679 |
+
bboxes: Bounding boxes to crop and resize
|
680 |
+
crop_coords: Coordinates of the crop (top, left, h, w)
|
681 |
+
orig_size: Size of the original image
|
682 |
+
|
683 |
+
Returns:
|
684 |
+
Cropped and resized bounding boxes
|
685 |
+
"""
|
686 |
+
orig_height, orig_width = orig_size
|
687 |
+
top, left, h, w = crop_coords
|
688 |
+
xmin, ymin, xmax, ymax = left, top, left + w, top + h
|
689 |
+
bboxes = [tuple(A.bbox_crop(bbox[:4], x_min=xmin, y_min=ymin, x_max=xmax, y_max=ymax, rows=orig_height,
|
690 |
+
cols=orig_width)) + tuple(bbox[4:])
|
691 |
+
for bbox in bboxes]
|
692 |
+
bboxes = A.core.bbox_utils.filter_bboxes(bboxes, rows=h, cols=w, min_visibility=self.min_visibility)
|
693 |
+
# No need to resize, bounding boxes in albumentations format are scale invariant
|
694 |
+
|
695 |
+
return bboxes
|
696 |
+
|
697 |
+
def order_and_filter_bboxes(self, bboxes):
|
698 |
+
if self.det_max_instances is not None and len(bboxes) > self.det_max_instances:
|
699 |
+
bboxes = self.order_bboxes_by_score(bboxes)[:self.det_max_instances]
|
700 |
+
|
701 |
+
return self.bbox_order(bboxes)
|
702 |
+
|
703 |
+
def convert_bboxes_to_string(self, bboxes: List[Tuple]):
|
704 |
+
"""Convert bounding boxes to a string.
|
705 |
+
xmin, ymin, xmax, ymax are mapped to v0, v1, v2, v3 special tokens.
|
706 |
+
|
707 |
+
Args:
|
708 |
+
bboxes: Bounding boxes
|
709 |
+
|
710 |
+
Returns:
|
711 |
+
String representation of the bounding boxes
|
712 |
+
"""
|
713 |
+
# Remove score, quantize coordinates
|
714 |
+
bins = self.coord_bins
|
715 |
+
|
716 |
+
bboxes = [
|
717 |
+
[
|
718 |
+
f"v0={round(xmin * (bins - 1))}",
|
719 |
+
f"v1={round(ymin * (bins - 1))}",
|
720 |
+
f"v2={round(xmax * (bins - 1))}",
|
721 |
+
f"v3={round(ymax * (bins - 1))}",
|
722 |
+
cls,
|
723 |
+
]
|
724 |
+
for (xmin, ymin, xmax, ymax, cls, score) in bboxes
|
725 |
+
]
|
726 |
+
# Convert each bounding box to a string
|
727 |
+
bboxes = [' '.join(b) for b in bboxes]
|
728 |
+
# Convert the list to a str
|
729 |
+
return ' '.join(bboxes)
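# Worked example (editor's addition, hypothetical values): with coord_bins=1000, a box
# (xmin=0.0, ymin=0.25, xmax=0.5, ymax=1.0, 'dog', 0.9) is encoded as the string
#   "v0=0 v1=250 v2=500 v3=999 dog"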
|
730 |
+
|
731 |
+
def load(self, path):
|
732 |
+
with open(path, 'r') as f:
|
733 |
+
sample = json.load(f)
|
734 |
+
|
735 |
+
return sample
|
736 |
+
|
737 |
+
def preprocess(self, sample):
|
738 |
+
instances = sample['instances']
|
739 |
+
return self.convert_detection_instance(instances)
|
740 |
+
|
741 |
+
def image_augment(self, bboxes: List[Tuple], crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
|
742 |
+
rand_aug_idx=None, resample_mode: str = None):
|
743 |
+
bboxes = self.bboxes_crop_and_resize(bboxes, crop_coords, orig_size)
|
744 |
+
bboxes = self.bboxes_hflip(bboxes, target_size, flip)
|
745 |
+
bboxes = self.order_and_filter_bboxes(bboxes)
|
746 |
+
return bboxes
|
747 |
+
|
748 |
+
def postprocess(self, bboxes):
|
749 |
+
if self.return_raw:
|
750 |
+
return bboxes
|
751 |
+
bboxes = self.convert_bboxes_to_string(bboxes)
|
752 |
+
return bboxes
|
753 |
+
|
754 |
+
|
755 |
+
class CaptionTransform(AbstractTransform):
|
756 |
+
|
757 |
+
def __init__(self, aligned_captions=True, no_aug=False):
|
758 |
+
self.aligned_captions = aligned_captions
|
759 |
+
self.no_aug = no_aug
|
760 |
+
|
761 |
+
def load(self, path):
|
762 |
+
# Caption can be stored as .txt, .json, or .json.gz (in the JSON cases it may be a list of dicts)
|
763 |
+
if path.endswith('.txt'):
|
764 |
+
sample = Path(path).read_text()
|
765 |
+
elif path.endswith('.json'):
|
766 |
+
with open(path, 'r') as f:
|
767 |
+
sample = json.load(f)
|
768 |
+
elif path.endswith('.json.gz'):
|
769 |
+
with gzip.open(path, 'rb') as f:
|
770 |
+
sample = json.load(f)
|
771 |
+
return sample
|
772 |
+
|
773 |
+
def preprocess(self, sample):
|
774 |
+
return sample
|
775 |
+
|
776 |
+
def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
|
777 |
+
rand_aug_idx: Optional[int], resample_mode: str = None):
|
778 |
+
|
779 |
+
if isinstance(val, list) or isinstance(val, tuple):
|
780 |
+
if self.aligned_captions:
|
781 |
+
val = val[0] if rand_aug_idx is None else val[rand_aug_idx]
|
782 |
+
else:
|
783 |
+
val = random.choice(val) if not self.no_aug else val[0]
|
784 |
+
|
785 |
+
if isinstance(val, dict):
|
786 |
+
# If each caption is saved as a dict, extract the string
|
787 |
+
val = val["caption"]
|
788 |
+
assert isinstance(val, str)
|
789 |
+
|
790 |
+
return val
|
791 |
+
|
792 |
+
def postprocess(self, sample):
|
793 |
+
return sample
|
794 |
+
|
795 |
+
|
796 |
+
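# --- Illustrative sketch (not part of the original diff) ---
# With aligned_captions=True, the caption chosen for a sample follows the same augmentation index
# as the pre-tokenized image crops, keeping text and image in sync. The caption values below are
# hypothetical.
def _example_aligned_caption(rand_aug_idx=1):
    transform = CaptionTransform(aligned_captions=True)
    captions = [{"caption": "caption for crop 0"}, {"caption": "caption for crop 1"}]
    # The geometric arguments are unused by CaptionTransform.image_augment
    return transform.image_augment(captions, crop_coords=None, flip=False, orig_size=None,
                                   target_size=None, rand_aug_idx=rand_aug_idx)  # -> "caption for crop 1"

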
class CaptionEmbTransform(AbstractTransform):

    def __init__(self, aligned_captions=True, no_aug=False):
        self.aligned_captions = aligned_captions
        self.no_aug = no_aug

    def load(self, path):
        if path.endswith('.npz'):
            sample = np.load(path)
            sample = {'emb': sample['emb'], 'mask_valid': sample['mask_valid']}
        else:
            raise ValueError(f"Invalid file format for caption embedding: {path}")
        return sample

    def preprocess(self, sample):
        return sample

    def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
                      rand_aug_idx: Optional[int], resample_mode: str = None):

        emb = val['emb']
        mask_valid = val['mask_valid'].astype(bool)
        num_sequences = emb.shape[0]

        if num_sequences > 1:
            if self.aligned_captions:
                if rand_aug_idx is None:
                    emb, mask_valid = emb[0], mask_valid[0]
                else:
                    emb, mask_valid = emb[rand_aug_idx], mask_valid[rand_aug_idx]
            else:
                if self.no_aug:
                    emb, mask_valid = emb[0], mask_valid[0]
                else:
                    rand_idx = random.randint(0, num_sequences - 1)
                    emb, mask_valid = emb[rand_idx], mask_valid[rand_idx]
        else:
            emb, mask_valid = emb[0], mask_valid[0]

        emb = emb[mask_valid]  # Keep only valid embeddings

        return emb

    def postprocess(self, sample):
        return torch.tensor(sample)

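# --- Illustrative sketch (not part of the original diff) ---
# The .npz layout CaptionEmbTransform.load expects: a stack of caption embeddings plus a boolean
# validity mask, one row per stored caption/augmentation. Shapes and the file name are hypothetical.
def _example_caption_emb_npz(path='caption_emb_example.npz', n_captions=2, seq_len=8, dim=4):
    emb = np.random.randn(n_captions, seq_len, dim).astype(np.float32)
    mask_valid = np.zeros((n_captions, seq_len), dtype=bool)
    mask_valid[:, :5] = True  # only the first 5 tokens of each caption are valid
    np.savez(path, emb=emb, mask_valid=mask_valid)
    return path

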
class MetadataTransform(AbstractTransform):

    def __init__(self,
                 special_vmin: int = 0,
                 special_vmax: int = 999,
                 shuffle: bool = True,
                 random_trunc: bool = False,
                 return_chunks: bool = True,
                 return_raw: bool = False,
                 image_dim_bin_size: int = 32):
        """Metadata transform that takes in a metadata dictionary and converts
        it into a string, or list of strings (for chunked span masking).
        Uses special tokens v1 to denote metadata types, and v0 for their values.

        Args:
            special_vmin: Minimum value for special tokens
            special_vmax: Maximum value for special tokens
            shuffle: Whether to shuffle the metadata order
            random_trunc: Whether to randomly truncate the returned metadata
            return_chunks: Whether to return a list of strings (for chunked span masking),
                or a single string with all metadata concatenated
            return_raw: Whether to return the raw metadata dictionary
        """
        self.special_vmin = special_vmin
        self.special_vmax = special_vmax
        self.shuffle = shuffle
        self.random_trunc = random_trunc
        self.return_chunks = return_chunks
        self.return_raw = return_raw
        self.image_dim_bin_size = image_dim_bin_size

        # Explicit map to make sure that additional entries do not change existing IDs
        # TODO: Make this work with other text tokenizers
        self.metadata_id_map = {
            'original_width': 'v1=0',
            'original_height': 'v1=1',
            'caption_n_chars': 'v1=2',
            'caption_n_words': 'v1=3',
            'caption_n_sentences': 'v1=4',
            'n_humans': 'v1=5',
            'n_sam_instances': 'v1=6',
            'n_coco_instances': 'v1=7',
            'coco_instance_diversity': 'v1=8',
            'colorfulness': 'v1=9',
            'brightness': 'v1=10',
            'contrast': 'v1=11',
            'saturation': 'v1=12',
            'entropy': 'v1=13',
            'walkability': 'v1=14',
            'objectness': 'v1=15',
            'semantic_diversity': 'v1=16',
            'geometric_complexity': 'v1=17',
            'occlusion_score': 'v1=18',
            'watermark_score': 'v1=19',
            'aesthetic_score': 'v1=20',
        }
        self.id_metadata_map = {v: k for k, v in self.metadata_id_map.items()}

        # Image-dimension modalities are binned into 32 bins
        self.image_dim_modalities = ['original_height', 'original_width']

        # Integer modalities that don't undergo any scaling (except for truncation)
        self.metadata_int_modalities = [
            'caption_n_chars', 'caption_n_words', 'caption_n_sentences',
            'n_humans', 'n_sam_instances', 'n_coco_instances',
            'coco_instance_diversity', 'semantic_diversity',
        ]

        # Bin boundaries for manually defined metadata modalities.
        # Lowest and highest bin boundaries are implicitly set to -inf and +inf
        self.metadata_manual_bins = {
            'watermark_score': [0.5],
            'aesthetic_score': [4.5, 5.5],
        }

        # All other float or integer modalities that are binned into a defined number of bins
        # Dictionary entries are (vmin, vmax, num_bins)
        self.metadata_min_max_bins = {
            'colorfulness': (0, 150, 50),
            'brightness': (0, 255, 50),
            'contrast': (0, 127, 50),
            'saturation': (0, 255, 50),
            'entropy': (0, 10, 50),
            'walkability': (0, 1, 50),
            'objectness': (0, 1, 50),
            'geometric_complexity': (0, 0.75, 50),
            'occlusion_score': (0, 0.25, 50),
        }

    def image_dim_to_string(self, metadata, key, bin_size=32):
        value = metadata[key] // bin_size
        value = max(self.special_vmin, min(value, self.special_vmax))
        return f"{self.metadata_id_map[key]} v0={value}"

    def int_metadata_to_string(self, metadata, key):
        value = max(self.special_vmin, min(metadata[key], self.special_vmax))
        return f"{self.metadata_id_map[key]} v0={value}"

    def float_metadata_to_string(self, metadata, key, vmin, vmax, bins):
        value = max(vmin, min(metadata[key], vmax))
        value = (value - vmin) / (vmax - vmin)
        value = int(value * (bins - 1))
        return f"{self.metadata_id_map[key]} v0={value}"

    def manual_bin_metadata_to_string(self, metadata, key):
        value = metadata[key]
        bin_idx = 0
        for bin_value in self.metadata_manual_bins[key]:
            if value < bin_value:
                break
            bin_idx += 1
        return f"{self.metadata_id_map[key]} v0={bin_idx}"

    def metadata_to_string(self, metadata, keys: List[str] = None):
        keys = list(metadata.keys()) if keys is None else keys

        if self.shuffle:
            # Randomly shuffle
            random.shuffle(keys)
        if self.random_trunc:
            # Randomly truncate
            keys = keys[:random.randint(1, len(keys))]

        metadata_strings = []

        for key in keys:
            if key in self.image_dim_modalities:
                # Image dimension modalities
                metadata_str = self.image_dim_to_string(metadata, key, bin_size=self.image_dim_bin_size)
            elif key in self.metadata_int_modalities:
                # Integer modalities that don't undergo any scaling
                metadata_str = self.int_metadata_to_string(metadata, key)
            elif key in self.metadata_manual_bins:
                # Metadata modalities for which bin boundaries are manually defined
                metadata_str = self.manual_bin_metadata_to_string(metadata, key)
            else:
                # All other modalities
                vmin, vmax, bins = self.metadata_min_max_bins[key]
                metadata_str = self.float_metadata_to_string(metadata, key, vmin, vmax, bins)

            metadata_strings.append(metadata_str)

        if self.return_chunks:
            return metadata_strings
        else:
            return ' '.join(metadata_strings)

    def load(self, path):
        with open(path, 'r') as f:
            sample = json.load(f)
        return sample

    def preprocess(self, sample):
        return sample

    def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
                      rand_aug_idx=None, resample_mode: str = None):
        return val

    def postprocess(self, metadata):
        if self.return_raw:
            return metadata
        metadata_str = self.metadata_to_string(metadata)
        return metadata_str

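# --- Illustrative sketch (not part of the original diff) ---
# The min/max binning used by MetadataTransform.float_metadata_to_string above, shown for a
# hypothetical brightness value of 200 with (vmin, vmax, bins) = (0, 255, 50).
def _example_float_bin(value=200, vmin=0, vmax=255, bins=50):
    value = max(vmin, min(value, vmax))     # clamp to [vmin, vmax]
    value = (value - vmin) / (vmax - vmin)  # normalize to [0, 1]
    return int(value * (bins - 1))          # -> 38, serialized as "v1=10 v0=38" for brightness

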
class HumanPoseTransform(AbstractTransform):
|
1010 |
+
|
1011 |
+
def __init__(self, coord_bins=1000, only_pose=False, return_raw=False):
|
1012 |
+
self.coord_bins = coord_bins
|
1013 |
+
self.return_raw = return_raw
|
1014 |
+
self.only_pose = only_pose
|
1015 |
+
|
1016 |
+
def convert_humanpose_instance(self, instances, only_pose=False):
|
1017 |
+
"""Convert instances dict to list of lists where each list takes the form:
|
1018 |
+
[human, xmin xmax ymin ymax global val1 val2 ... val10 pose val1 val2 ... val 207 shape val1 val2 ... val10 camera val1 val2 val3 val4]
|
1019 |
+
Like for bounding boxes, xmin, ymin, xmax, and ymax map to v0, v1, v2, and v3 respectively.
|
1020 |
+
"""
|
1021 |
+
if only_pose: # used for tokenizer training for pose
|
1022 |
+
if len(instances) == 0:
|
1023 |
+
return torch.zeros(207)
|
1024 |
+
else:
|
1025 |
+
return torch.from_numpy(np.array(instances['pred_smpl_params']['body_pose'][0]).flatten()).float()
|
1026 |
+
if len(instances) == 0: #empty, i.e. there are no humans
|
1027 |
+
return 'none'
|
1028 |
+
|
1029 |
+
for k in instances:
|
1030 |
+
if k!='pred_smpl_params':
|
1031 |
+
instances[k] = torch.from_numpy(np.array(instances[k]))
|
1032 |
+
|
1033 |
+
smpl_params = (instances['pred_smpl_params'])
|
1034 |
+
|
1035 |
+
for k in smpl_params:
|
1036 |
+
smpl_params[k] = torch.from_numpy(np.array(smpl_params[k]))
|
1037 |
+
|
1038 |
+
total_num_instances = len(instances['bbox_xyxy'])
|
1039 |
+
instances_converted = []
|
1040 |
+
for ii in range(total_num_instances):
|
1041 |
+
instances_converted.append(['human'] + (np.array(instances['bbox_xyxy'][ii]).flatten().tolist()) + ['global'] + (np.array(instances['pred_smpl_params']['global_orient'][ii]).flatten().tolist()) + ['pose'] + (instances['pose_tokenized'][ii].flatten().tolist()) + ['shape'] + (instances['pred_smpl_params']['betas'][ii].flatten().tolist()) + ['camera'] + (instances['pred_cam'][ii].flatten().tolist()))
|
1042 |
+
return instances_converted
|
1043 |
+
|
1044 |
+
def humanposes_crop_and_resize(self, humanposes: List[Tuple], crop_coords: Tuple, orig_size: Tuple,):
|
1045 |
+
"""Crop and resize human poses (and their bounding boxes)
|
1046 |
+
"""
|
1047 |
+
orig_height, orig_width = orig_size
|
1048 |
+
top, left, h, w = crop_coords
|
1049 |
+
|
1050 |
+
humanposes_converted_resized = []
|
1051 |
+
for instance in humanposes:
|
1052 |
+
bbox_curr = instance[1:5]
|
1053 |
+
bbox_curr = np.array(bbox_curr)
|
1054 |
+
bbox_curr[0::2] = bbox_curr[0::2] / orig_width
|
1055 |
+
bbox_curr[1::2] = bbox_curr[1::2] / orig_height
|
1056 |
+
|
1057 |
+
xmin, ymin, xmax, ymax = left, top, left + w, top + h
|
1058 |
+
bbox_curr = A.bbox_crop(bbox_curr, x_min=xmin, y_min=ymin, x_max=xmax, y_max=ymax, rows=orig_height,
|
1059 |
+
cols=orig_width)
|
1060 |
+
bbox_curr = np.array(bbox_curr)
|
1061 |
+
if np.all(bbox_curr[1::2]<0) or np.all(bbox_curr[0::2]<0): #bbox is out of range, remove it
|
1062 |
+
continue
|
1063 |
+
if np.all(bbox_curr[1::2]>1.0) or np.all(bbox_curr[0::2]>1.0): #bbox is out of range, remove it
|
1064 |
+
continue
|
1065 |
+
bbox_curr = np.clip(bbox_curr, a_min=0, a_max=1.)
|
1066 |
+
|
1067 |
+
instance[1:5] = bbox_curr
|
1068 |
+
humanposes_converted_resized.append(instance)
|
1069 |
+
|
1070 |
+
# now return all instances, or none if there is no instance
|
1071 |
+
if len(humanposes_converted_resized)>0:
|
1072 |
+
pass
|
1073 |
+
else: #no valid masks remains
|
1074 |
+
return 'none'
|
1075 |
+
|
1076 |
+
humanpose_returned = humanposes_converted_resized
|
1077 |
+
|
1078 |
+
return humanpose_returned
|
1079 |
+
|
1080 |
+
def convert_humanposes_to_string(self, all_humanposes: List[Tuple]):
|
1081 |
+
"""Convert humanposes to a string
|
1082 |
+
range of global orientation: [-1, 1]
|
1083 |
+
range of object pose: [-1, 1]
|
1084 |
+
range of shape (betas): [-3, 3]
|
1085 |
+
range of camera: [-1, 19]
|
1086 |
+
"""
|
1087 |
+
bins = self.coord_bins
|
1088 |
+
|
1089 |
+
instance_final_all = ''
|
1090 |
+
|
1091 |
+
for humanposes in all_humanposes:
|
1092 |
+
human = humanposes[0]
|
1093 |
+
bboxes = humanposes[1:5]
|
1094 |
+
glob = humanposes[5]
|
1095 |
+
global_orient = np.array(humanposes[6:15])
|
1096 |
+
pose = humanposes[15]
|
1097 |
+
pose_params = np.array(humanposes[16:24])
|
1098 |
+
shape = humanposes[24]
|
1099 |
+
shape_params = np.array(humanposes[25:35])
|
1100 |
+
camera = humanposes[35]
|
1101 |
+
camera_params = np.clip(np.array(humanposes[36:]), a_min=-1., a_max=19.)
|
1102 |
+
|
1103 |
+
bboxes_new = [
|
1104 |
+
f"v0={round(bboxes[0] * (bins - 1))}",
|
1105 |
+
f"v1={round(bboxes[1] * (bins - 1))}",
|
1106 |
+
f"v2={round(bboxes[2] * (bins - 1))}",
|
1107 |
+
f"v3={round(bboxes[3] * (bins - 1))}"]
|
1108 |
+
|
1109 |
+
global_orient = 499.5*global_orient
|
1110 |
+
global_orient_new = []
|
1111 |
+
for ii in range(len(global_orient)):
|
1112 |
+
global_orient_curr = f"v0={round(global_orient[ii]+499.5)}"
|
1113 |
+
global_orient_new.append(global_orient_curr)
|
1114 |
+
|
1115 |
+
pose_params_new = []
|
1116 |
+
for ii in range(len(pose_params)):
|
1117 |
+
if pose_params[ii]<512:
|
1118 |
+
pose_params_curr = f"v0={round(pose_params[ii])}"
|
1119 |
+
else:
|
1120 |
+
pose_params_curr = f"v1={round(pose_params[ii] - 512)}"
|
1121 |
+
pose_params_new.append(pose_params_curr)
|
1122 |
+
|
1123 |
+
shape_params = 166.5*shape_params
|
1124 |
+
shape_params_new = []
|
1125 |
+
for ii in range(len(shape_params)):
|
1126 |
+
shape_params_curr = f"v0={round(shape_params[ii]+499.5)}"
|
1127 |
+
shape_params_new.append(shape_params_curr)
|
1128 |
+
|
1129 |
+
camera_params = 49.95*camera_params
|
1130 |
+
camera_params_new = []
|
1131 |
+
for ii in range(len(camera_params)):
|
1132 |
+
camera_params_curr = f"v0={round(camera_params[ii]+49.95)}"
|
1133 |
+
camera_params_new.append(camera_params_curr)
|
1134 |
+
|
1135 |
+
#randomly shuffle everything except bbox part of the sequence
|
1136 |
+
all_strings = [[pose]+pose_params_new, [glob] + global_orient_new, [camera] + camera_params_new, [shape] + shape_params_new ]
|
1137 |
+
rand_perm = torch.randperm(4)
|
1138 |
+
instance_final = [human] + bboxes_new + all_strings[rand_perm[0]] + all_strings[rand_perm[1]] + all_strings[rand_perm[2]] + all_strings[rand_perm[3]]
|
1139 |
+
|
1140 |
+
|
1141 |
+
instance_final = ', '.join(instance_final)
|
1142 |
+
instance_final = instance_final.replace(",", "")
|
1143 |
+
instance_final_all = instance_final_all + instance_final + ' '
|
1144 |
+
|
1145 |
+
return instance_final_all
|
1146 |
+
|
1147 |
+
def load(self, path):
|
1148 |
+
with open(path, 'r') as f:
|
1149 |
+
sample = json.load(f)
|
1150 |
+
|
1151 |
+
return sample
|
1152 |
+
|
1153 |
+
def preprocess(self, sample):
|
1154 |
+
instances = sample
|
1155 |
+
instances = self.convert_humanpose_instance(instances, only_pose=self.only_pose)
|
1156 |
+
return instances
|
1157 |
+
|
1158 |
+
def image_augment(self, humanposes: List[Tuple], crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
|
1159 |
+
rand_aug_idx=None, resample_mode: str = None):
|
1160 |
+
if humanposes=='none' or self.only_pose:
|
1161 |
+
return humanposes
|
1162 |
+
humanposes = self.humanposes_crop_and_resize(humanposes, crop_coords, orig_size)
|
1163 |
+
return humanposes
|
1164 |
+
|
1165 |
+
def postprocess(self, humanposes):
|
1166 |
+
if humanposes=='none' or self.only_pose:
|
1167 |
+
return humanposes if not self.return_raw else []
|
1168 |
+
if self.return_raw:
|
1169 |
+
return humanposes
|
1170 |
+
humanposes = self.convert_humanposes_to_string(humanposes)
|
1171 |
+
return humanposes
|
1172 |
+
|
1173 |
+
|
1174 |
+
class ColorPaletteTransform(AbstractTransform):

    def __init__(self, coord_bins=1000, return_raw=False):
        self.coord_bins = coord_bins
        self.return_raw = return_raw

    def convert_palette_instance(self, instances):
        """Convert colors to v0= v0= ...
        """
        length = random.randint(1, 7)
        instances_converted = np.array(instances[0][str(length)]).flatten().tolist()
        return instances_converted

    def palette_hflip(self, palettes: List[Tuple], image_size: Tuple, flip: bool):
        return palettes

    def convert_palettes_to_string(self, all_palettes: List[Tuple]):
        """Convert palettes to a string
        """
        colors = []
        len_palettes = len(all_palettes)
        colors.append(f"v1={round(len_palettes / 3)}")  # start with the length of the color palette to avoid confusion
        for ii in range(len(all_palettes)):
            color_new = f"v0={round(all_palettes[ii])}"
            colors.append(color_new)

        instance_final_all = colors
        instance_final_all = ', '.join(instance_final_all)
        instance_final_all = instance_final_all.replace(",", "")

        return instance_final_all

    def load(self, path):
        with open(path, 'r') as f:
            sample = json.load(f)
        return sample

    def preprocess(self, sample):
        if self.return_raw:
            return sample
        instances = sample
        instances = self.convert_palette_instance(instances)
        return instances

    def image_augment(self, palettes: List[Tuple], crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
                      rand_aug_idx=None, resample_mode: str = None):
        return palettes

    def postprocess(self, palettes):
        if self.return_raw:
            return palettes
        palettes = self.convert_palettes_to_string(palettes)
        return palettes

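# --- Illustrative sketch (not part of the original diff) ---
# How a flattened RGB palette is serialized by ColorPaletteTransform.convert_palettes_to_string
# above: the sequence starts with the number of colors as a v1 token, followed by one v0 token per
# channel value. The palette values are hypothetical.
def _example_palette_string():
    transform = ColorPaletteTransform()
    flat_palette = [255, 0, 0, 0, 255, 0, 0, 0, 255]  # three RGB colors, flattened
    return transform.convert_palettes_to_string(flat_palette)
    # -> 'v1=3 v0=255 v0=0 v0=0 v0=0 v0=255 v0=0 v0=0 v0=0 v0=255'

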
class SAMInstanceTokTransform(AbstractTransform):

    def __init__(self, image_size=224, points_per_side=7, point_order='random'):
        self.H, self.W = to_2tuple(image_size)
        self.points_per_h, self.points_per_w = to_2tuple(points_per_side)
        assert point_order in ['random', 'grid']
        self.point_order = point_order

    def get_query_points(self):
        if self.point_order == 'grid':
            # Create and cache grid query points
            if not hasattr(self, 'grid_query_points'):
                y, x = np.meshgrid(np.linspace(0, self.H, self.points_per_h + 2)[1:-1], np.linspace(0, self.W, self.points_per_w + 2)[1:-1])
                grid = np.stack((x, y), axis=2).astype(np.int32)
                self.grid_query_points = grid.reshape(-1, 2)
            return self.grid_query_points
        elif self.point_order == 'random':
            # Randomly sample query points
            y = np.random.randint(0, self.H, self.points_per_h)
            x = np.random.randint(0, self.W, self.points_per_w)
            return np.concatenate((x[:, None], y[:, None]), axis=1)
        else:
            raise ValueError(f"Query point order mode {self.point_order} is not supported.")

    def get_target_tokens(self, sample, query_points):
        instances_coords = [coords[0] for coords in sample['points']]
        tokens = sample['token_ids']
        bboxes = sample['bbox']

        instance_tokens_per_qpoint = dict()
        for point in query_points:
            point = (int(point[0].item()), int(point[1].item()))
            instance_tokens_per_qpoint[point] = []
            for i, (coords, tok, bbox) in enumerate(zip(instances_coords, tokens, bboxes)):
                # Calculate the distance from the query point to the instance
                distance = cv2.pointPolygonTest(coords, point, measureDist=True)
                # If the query point is inside the instance, add its corresponding token
                if distance >= 0:
                    instance_tokens_per_qpoint[point].append((tok, bbox))

        return instance_tokens_per_qpoint

    def convert_target_tokens_to_string(self, target_tokens):
        result_text = []
        query_points = list(target_tokens.keys())
        # Randomly shuffle query points order (mainly for grid order)
        random.shuffle(query_points)
        for point in query_points:

            # Add query point coordinates to the string
            result_text.append('point')
            result_text.append(f'v0={point[1]}')
            result_text.append(f'v1={point[0]}')

            # Randomly shuffle the order of instance tokens per query point
            random.shuffle(target_tokens[point])
            if len(target_tokens[point]) == 0:
                # If no instance tokens are found, add 'none' to the string
                result_text.append('none')
            else:
                for tok, bbox in target_tokens[point]:
                    result_text.append('polygon')

                    # Add bounding box coordinates to the string
                    ymin, xmin, ymax, xmax = bbox.astype(np.int32)
                    result_text.extend([
                        f'v0={xmin}',
                        f'v1={ymin}',
                        f'v2={xmax}',
                        f'v3={ymax}',
                    ])

                    # Add instance token ids to the string
                    for idx in tok.tolist():
                        if idx < 512:
                            result_text.append(f'v0={idx}')
                        else:
                            result_text.append(f'v1={idx - 512}')

        return " ".join(result_text)

    def load(self, path):
        sample = np.load(path, allow_pickle=True)
        return sample

    def preprocess(self, sample):
        for s in sample:
            s['token_ids'] = s['token_ids'].astype(np.int32)
        return sample

    def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
                      rand_aug_idx: Optional[int], resample_mode: str = None):
        if rand_aug_idx is None:
            raise ValueError("Crop settings / augmentation index are missing but a pre-tokenized modality is being used")
        v = v[rand_aug_idx]
        return v

    def postprocess(self, sample):
        query_points = self.get_query_points()
        target_tokens = self.get_target_tokens(sample, query_points)
        final_string = self.convert_target_tokens_to_string(target_tokens)
        return final_string

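# --- Illustrative sketch (not part of the original diff) ---
# The 'grid' query-point layout used by SAMInstanceTokTransform.get_query_points above: for a
# 224x224 image and points_per_side=7, linspace(0, 224, 9)[1:-1] yields 7 evenly spaced interior
# coordinates per axis, i.e. a 7x7 = 49-point query grid.
def _example_grid_query_points(image_size=224, points_per_side=7):
    transform = SAMInstanceTokTransform(image_size=image_size, points_per_side=points_per_side,
                                        point_order='grid')
    return transform.get_query_points()  # -> array of shape (49, 2) with (x, y) pairs

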
class CropSettingsTransform(AbstractTransform):

    def load(self, path):
        sample = np.load(path)
        return sample

    def preprocess(self, sample):
        raise NotImplementedError("CropSettingsTransform does not support preprocessing")

    def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
                      rand_aug_idx: Optional[int], resample_mode: str = None):
        raise NotImplementedError("CropSettingsTransform is not meant to be used for image augmentation")

    def postprocess(self, sample):
        raise NotImplementedError("CropSettingsTransform does not support postprocessing")


class IdentityTransform(AbstractTransform):

    def load(self, path):
        raise NotImplementedError("IdentityTransform does not support loading")

    def preprocess(self, sample):
        return sample

    def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
                      rand_aug_idx: Optional[int], resample_mode: str = None):
        return val

    def postprocess(self, sample):
        return sample


class JSONTransform(AbstractTransform):

    def load(self, path):
        if path.endswith('.json'):
            with open(path, 'r') as f:
                sample = json.load(f)
        elif path.endswith('.json.gz'):
            with gzip.open(path, 'rb') as f:
                sample = json.load(f)
        return sample

    def preprocess(self, sample):
        return sample

    def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
                      rand_aug_idx: Optional[int], resample_mode: str = None):
        return val

    def postprocess(self, sample):
        return sample

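# --- Illustrative sketch (not part of the original diff) ---
# A minimal mapping from modality names to the transforms defined above, in the spirit of the
# MODALITY_TRANSFORMS dict that fourm/data/modality_info.py builds. The modality keys shown here
# are assumptions for illustration only.
def _example_modality_transforms():
    return {
        'caption': CaptionTransform(aligned_captions=True),
        'metadata': MetadataTransform(shuffle=True, random_trunc=False, return_chunks=True),
        'color_palette': ColorPaletteTransform(),
        'crop_settings': CropSettingsTransform(),
    }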
fourm/data/multimodal_dataset_folder.py
ADDED
@@ -0,0 +1,363 @@
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import os.path
|
16 |
+
import pickle
|
17 |
+
import random
|
18 |
+
from copy import deepcopy
|
19 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
from torchvision.datasets.vision import VisionDataset
|
23 |
+
|
24 |
+
from fourm.data.modality_transforms import AbstractTransform, get_transform_key
|
25 |
+
|
26 |
+
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp', '.jpx', '.npy', '.npz')
|
27 |
+
|
28 |
+
UNIFIED_EXTENSIONS = IMG_EXTENSIONS + ('.json', '.txt', '.json.gz')
|
29 |
+
|
30 |
+
|
31 |
+
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
|
32 |
+
"""Checks if a file is an allowed extension.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
filename (string): path to a file
|
36 |
+
extensions (tuple of strings): extensions to consider (lowercase)
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
bool: True if the filename ends with one of given extensions
|
40 |
+
"""
|
41 |
+
return filename.lower().endswith(extensions)
|
42 |
+
|
43 |
+
|
44 |
+
def is_image_file(filename: str) -> bool:
|
45 |
+
"""Checks if a file is an allowed image extension.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
filename (string): path to a file
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
bool: True if the filename ends with a known image extension
|
52 |
+
"""
|
53 |
+
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
|
54 |
+
|
55 |
+
|
56 |
+
def make_dataset(
|
57 |
+
directory: str,
|
58 |
+
class_to_idx: Dict[str, int],
|
59 |
+
extensions: Optional[Tuple[str, ...]] = None,
|
60 |
+
is_valid_file: Optional[Callable[[str], bool]] = None,
|
61 |
+
cache_path: Optional[str] = None,
|
62 |
+
) -> List[Tuple[str, int]]:
|
63 |
+
if cache_path is not None and os.path.exists(cache_path):
|
64 |
+
# Load cached file paths from disk if it exists
|
65 |
+
with open(cache_path, 'rb') as f:
|
66 |
+
return pickle.load(f)
|
67 |
+
instances = []
|
68 |
+
directory = os.path.expanduser(directory)
|
69 |
+
both_none = extensions is None and is_valid_file is None
|
70 |
+
both_something = extensions is not None and is_valid_file is not None
|
71 |
+
if both_none or both_something:
|
72 |
+
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
|
73 |
+
if extensions is not None:
|
74 |
+
def is_valid_file(x: str) -> bool:
|
75 |
+
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
|
76 |
+
is_valid_file = cast(Callable[[str], bool], is_valid_file)
|
77 |
+
for target_class in sorted(class_to_idx.keys()):
|
78 |
+
class_index = class_to_idx[target_class]
|
79 |
+
target_dir = os.path.join(directory, target_class)
|
80 |
+
if not os.path.isdir(target_dir):
|
81 |
+
continue
|
82 |
+
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
|
83 |
+
for fname in sorted(fnames):
|
84 |
+
path = os.path.join(root, fname)
|
85 |
+
if is_valid_file(path):
|
86 |
+
item = path, class_index
|
87 |
+
instances.append(item)
|
88 |
+
if cache_path is not None:
|
89 |
+
# Cache all file paths s.t. setting up the dataloader is instant in the future
|
90 |
+
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
|
91 |
+
with open(cache_path, 'wb') as f:
|
92 |
+
pickle.dump(instances, f)
|
93 |
+
return instances
|
94 |
+
|
95 |
+
|
96 |
+
class DatasetFolder(VisionDataset):
|
97 |
+
"""A generic data loader where the samples are arranged in this way: ::
|
98 |
+
|
99 |
+
root/class_x/xxx.ext
|
100 |
+
root/class_x/xxy.ext
|
101 |
+
root/class_x/xxz.ext
|
102 |
+
|
103 |
+
root/class_y/123.ext
|
104 |
+
root/class_y/nsdf3.ext
|
105 |
+
root/class_y/asd932_.ext
|
106 |
+
|
107 |
+
Args:
|
108 |
+
root (string): Root directory path.
|
109 |
+
loader (callable): A function to load a sample given its path.
|
110 |
+
extensions (tuple[string]): A list of allowed extensions.
|
111 |
+
both extensions and is_valid_file should not be passed.
|
112 |
+
transform (callable, optional): A function/transform that takes in
|
113 |
+
a sample and returns a transformed version.
|
114 |
+
E.g, ``transforms.RandomCrop`` for images.
|
115 |
+
target_transform (callable, optional): A function/transform that takes
|
116 |
+
in the target and transforms it.
|
117 |
+
is_valid_file (callable, optional): A function that takes path of a file
|
118 |
+
and check if the file is a valid file (used to check of corrupt logs)
|
119 |
+
both extensions and is_valid_file should not be passed.
|
120 |
+
|
121 |
+
Attributes:
|
122 |
+
classes (list): List of the class names sorted alphabetically.
|
123 |
+
class_to_idx (dict): Dict with items (class_name, class_index).
|
124 |
+
samples (list): List of (sample path, class_index) tuples
|
125 |
+
targets (list): The class_index value for each image in the dataset
|
126 |
+
"""
|
127 |
+
|
128 |
+
def __init__(
|
129 |
+
self,
|
130 |
+
root: str,
|
131 |
+
loader: Callable[[str], Any],
|
132 |
+
extensions: Optional[Tuple[str, ...]] = None,
|
133 |
+
transform: Optional[Callable] = None,
|
134 |
+
target_transform: Optional[Callable] = None,
|
135 |
+
is_valid_file: Optional[Callable[[str], bool]] = None,
|
136 |
+
) -> None:
|
137 |
+
super(DatasetFolder, self).__init__(root, transform=transform,
|
138 |
+
target_transform=target_transform)
|
139 |
+
classes, class_to_idx = self._find_classes(self.root)
|
140 |
+
samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
|
141 |
+
if len(samples) == 0:
|
142 |
+
msg = "Found 0 logs in subfolders of: {}\n".format(self.root)
|
143 |
+
if extensions is not None:
|
144 |
+
msg += "Supported extensions are: {}".format(",".join(extensions))
|
145 |
+
raise RuntimeError(msg)
|
146 |
+
|
147 |
+
self.loader = loader
|
148 |
+
self.extensions = extensions
|
149 |
+
|
150 |
+
self.classes = classes
|
151 |
+
self.class_to_idx = class_to_idx
|
152 |
+
self.samples = samples
|
153 |
+
self.targets = [s[1] for s in samples]
|
154 |
+
|
155 |
+
def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
|
156 |
+
"""
|
157 |
+
Finds the class folders in a dataset.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
dir (string): Root directory path.
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
|
164 |
+
|
165 |
+
Ensures:
|
166 |
+
No class is a subdirectory of another.
|
167 |
+
"""
|
168 |
+
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
|
169 |
+
classes.sort()
|
170 |
+
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
|
171 |
+
return classes, class_to_idx
|
172 |
+
|
173 |
+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
174 |
+
"""
|
175 |
+
Args:
|
176 |
+
index (int): Index
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
tuple: (sample, target) where target is class_index of the target class.
|
180 |
+
"""
|
181 |
+
while True:
|
182 |
+
try:
|
183 |
+
path, target = self.samples[index]
|
184 |
+
sample = self.loader(path)
|
185 |
+
break
|
186 |
+
except Exception as e:
|
187 |
+
print(e)
|
188 |
+
index = random.randint(0, len(self.samples) - 1)
|
189 |
+
|
190 |
+
if self.transform is not None:
|
191 |
+
sample = self.transform(sample)
|
192 |
+
if self.target_transform is not None:
|
193 |
+
target = self.target_transform(target)
|
194 |
+
|
195 |
+
return sample, target
|
196 |
+
|
197 |
+
def __len__(self) -> int:
|
198 |
+
return len(self.samples)
|
199 |
+
|
200 |
+
|
201 |
+
class MultiModalDatasetFolder(VisionDataset):
|
202 |
+
"""A generic multi-modal dataset loader where the samples are arranged in this way: ::
|
203 |
+
|
204 |
+
root/modality_a/class_x/xxx.ext
|
205 |
+
root/modality_a/class_y/xxy.ext
|
206 |
+
root/modality_a/class_z/xxz.ext
|
207 |
+
|
208 |
+
root/modality_b/class_x/xxx.ext
|
209 |
+
root/modality_b/class_y/xxy.ext
|
210 |
+
root/modality_b/class_z/xxz.ext
|
211 |
+
|
212 |
+
Args:
|
213 |
+
root (string): Root directory path.
|
214 |
+
modalities (list): List of modalities as strings
|
215 |
+
modality_paths (dict): Dict of paths to modalities
|
216 |
+
modality_transforms (dict): Dict of transforms for each modality
|
217 |
+
loader (callable): A function to load a sample given its path.
|
218 |
+
transform (callable, optional): A function/transform that takes in
|
219 |
+
a sample and returns a transformed version.
|
220 |
+
E.g, ``transforms.RandomCrop`` for images.
|
221 |
+
target_transform (callable, optional): A function/transform that takes
|
222 |
+
in the target and transforms it.
|
223 |
+
is_valid_file (callable, optional): A function that takes path of a file
|
224 |
+
and check if the file is a valid file (used to check of corrupt logs)
|
225 |
+
both extensions and is_valid_file should not be passed.
|
226 |
+
max_samples (int, optional): Maximum number of samples to load. If None, all samples are loaded.
|
227 |
+
pre_shuffle (bool, optional): Whether to shuffle the sample during the init.
|
228 |
+
return_paths (bool, optional): Whether to return the paths of the samples.
|
229 |
+
cache (bool, optional): Whether to cache the samples in memory. If True, the samples are loaded only once and then cached in memory.
|
230 |
+
|
231 |
+
Attributes:
|
232 |
+
classes (list): List of the class names sorted alphabetically.
|
233 |
+
class_to_idx (dict): Dict with items (class_name, class_index).
|
234 |
+
samples (list): List of (sample path, class_index) tuples
|
235 |
+
targets (list): The class_index value for each image in the dataset
|
236 |
+
"""
|
237 |
+
|
238 |
+
def __init__(
|
239 |
+
self,
|
240 |
+
root: str,
|
241 |
+
modalities: List[str],
|
242 |
+
modality_paths: Dict[str, str],
|
243 |
+
modality_transforms: Dict[str, AbstractTransform],
|
244 |
+
transform: Optional[Callable] = None,
|
245 |
+
target_transform: Optional[Callable] = None,
|
246 |
+
is_valid_file: Optional[Callable[[str], bool]] = None,
|
247 |
+
max_samples: Optional[int] = None,
|
248 |
+
pre_shuffle: bool = False,
|
249 |
+
cache: bool = False,
|
250 |
+
return_path: bool = False,
|
251 |
+
) -> None:
|
252 |
+
super(MultiModalDatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform)
|
253 |
+
self.modalities = modalities
|
254 |
+
# If modality_paths is not provided, use the default paths
|
255 |
+
self.modality_paths = modality_paths
|
256 |
+
for mod in self.modalities:
|
257 |
+
if mod not in self.modality_paths:
|
258 |
+
modality_paths[mod] = mod
|
259 |
+
self.modality_transforms = modality_transforms
|
260 |
+
self.return_path = return_path
|
261 |
+
|
262 |
+
classes, class_to_idx = self._find_classes(os.path.join(self.root, list(self.modality_paths.values())[0]))
|
263 |
+
extensions = UNIFIED_EXTENSIONS if is_valid_file is None else None
|
264 |
+
|
265 |
+
samples = {
|
266 |
+
mod: make_dataset(
|
267 |
+
os.path.join(self.root, f'{self.modality_paths[mod]}'),
|
268 |
+
class_to_idx,
|
269 |
+
extensions,
|
270 |
+
is_valid_file,
|
271 |
+
cache_path=os.path.join(self.root, 'dataloader_cache', f'{self.modality_paths[mod]}.pkl') if cache else None)
|
272 |
+
for mod in self.modalities
|
273 |
+
}
|
274 |
+
|
275 |
+
for mod, mod_samples in samples.items():
|
276 |
+
if len(mod_samples) == 0:
|
277 |
+
msg = "Found 0 logs in subfolders of: {}\n".format(os.path.join(self.root, f'{self.modality_paths[mod]}'))
|
278 |
+
if extensions is not None:
|
279 |
+
msg += "Supported extensions are: {}".format(",".join(extensions))
|
280 |
+
raise RuntimeError(msg)
|
281 |
+
|
282 |
+
self.extensions = extensions
|
283 |
+
|
284 |
+
self.classes = classes
|
285 |
+
self.class_to_idx = class_to_idx
|
286 |
+
self.samples = samples
|
287 |
+
|
288 |
+
# Select random subset of dataset if so specified
|
289 |
+
if isinstance(max_samples, int):
|
290 |
+
total_samples = len(list(self.samples.values())[0])
|
291 |
+
np.random.seed(0)
|
292 |
+
permutation = np.random.permutation(total_samples)
|
293 |
+
for task in samples:
|
294 |
+
self.samples[task] = [self.samples[task][i] for i in permutation][:max_samples]
|
295 |
+
|
296 |
+
if pre_shuffle:
|
297 |
+
total_samples = len(list(self.samples.values())[0])
|
298 |
+
np.random.seed(100)
|
299 |
+
permutation = np.random.permutation(total_samples)
|
300 |
+
for task in samples:
|
301 |
+
self.samples[task] = [self.samples[task][i] for i in permutation]
|
302 |
+
|
303 |
+
self.cache = {}
|
304 |
+
self.imgs = self.samples
|
305 |
+
|
306 |
+
def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
|
307 |
+
"""
|
308 |
+
Finds the class folders in a dataset.
|
309 |
+
|
310 |
+
Args:
|
311 |
+
dir (string): Root directory path.
|
312 |
+
|
313 |
+
Returns:
|
314 |
+
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
|
315 |
+
|
316 |
+
Ensures:
|
317 |
+
No class is a subdirectory of another.
|
318 |
+
"""
|
319 |
+
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
|
320 |
+
classes.sort()
|
321 |
+
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
|
322 |
+
return classes, class_to_idx
|
323 |
+
|
324 |
+
def get_class_and_file(self, path: str) -> Tuple[str, str]:
|
325 |
+
""" Extracts the class and file name from a path. """
|
326 |
+
class_id, file_name = path.split('/')[-2:]
|
327 |
+
file_name = file_name.split('.')[0]
|
328 |
+
return class_id, file_name
|
329 |
+
|
330 |
+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
331 |
+
"""
|
332 |
+
Args:
|
333 |
+
index (int): Index
|
334 |
+
|
335 |
+
Returns:
|
336 |
+
tuple: (sample, target) where target is class_index of the target class.
|
337 |
+
"""
|
338 |
+
if index in self.cache:
|
339 |
+
sample_dict, target = deepcopy(self.cache[index])
|
340 |
+
else:
|
341 |
+
sample_dict = {}
|
342 |
+
for mod in self.modalities:
|
343 |
+
path, target = self.samples[mod][index]
|
344 |
+
sample = self.modality_transforms[get_transform_key(mod)].load(path)
|
345 |
+
sample_dict[mod] = sample
|
346 |
+
# self.cache[index] = deepcopy((sample_dict, target))
|
347 |
+
|
348 |
+
if self.transform is not None:
|
349 |
+
sample_dict = self.transform(sample_dict)
|
350 |
+
if self.target_transform is not None:
|
351 |
+
target = self.target_transform(target)
|
352 |
+
|
353 |
+
sample_dict['class_idx'] = target
|
354 |
+
|
355 |
+
if self.return_path and not index in self.cache:
|
356 |
+
class_id, file_name = self.get_class_and_file(path)
|
357 |
+
sample_dict['class_id'] = class_id
|
358 |
+
sample_dict['file_name'] = file_name
|
359 |
+
|
360 |
+
return sample_dict
|
361 |
+
|
362 |
+
def __len__(self) -> int:
|
363 |
+
return len(list(self.samples.values())[0])
|
fourm/data/pretrain_utils.py
ADDED
@@ -0,0 +1,292 @@
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import copy
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import yaml
|
18 |
+
|
19 |
+
import fourm.utils as utils
|
20 |
+
|
21 |
+
from fourm.data import (CenterCropImageAugmenter, EmptyAugmenter,
|
22 |
+
PreTokenizedImageAugmenter,RandomCropImageAugmenter, build_fm_pretraining_dataset,
|
23 |
+
build_huggingface_pretraining_dataloader,
|
24 |
+
build_wds_fm_pretraining_dataloader)
|
25 |
+
from fourm.data.modality_transforms import CaptionTransform
|
26 |
+
from fourm.data.modality_info import MODALITY_TRANSFORMS
|
27 |
+
|
28 |
+
|
29 |
+
def setup_sampling_mod_info(dataset_config, modality_info):
|
30 |
+
# Subset of modality info for each dataset
|
31 |
+
|
32 |
+
# Input and output modalities for one dataset
|
33 |
+
in_domains = sorted(dataset_config['in_domains'].split('-'))
|
34 |
+
out_domains = sorted(dataset_config['out_domains'].split('-'))
|
35 |
+
all_domains = sorted(list(set(in_domains) | set(out_domains)))
|
36 |
+
|
37 |
+
mod_info = copy.deepcopy(modality_info)
|
38 |
+
mod_info = {mod: mod_info[mod] for mod in all_domains}
|
39 |
+
|
40 |
+
# Dirichlet concentration parameter (Alpha)
|
41 |
+
if dataset_config.get('alphas_config', None) is None:
|
42 |
+
for mod in mod_info:
|
43 |
+
mod_info[mod]["input_alphas"] = [0.]
|
44 |
+
mod_info[mod]["target_alphas"] = [0.]
|
45 |
+
|
46 |
+
if 'input_alphas' in dataset_config:
|
47 |
+
input_alphas = dataset_config['input_alphas'].split('-')
|
48 |
+
if len(input_alphas) == 1:
|
49 |
+
input_alphas = [float(input_alphas[0])] * len(in_domains)
|
50 |
+
else:
|
51 |
+
input_alphas = [float(alpha) for alpha in input_alphas]
|
52 |
+
for mod, alpha in zip(in_domains, input_alphas):
|
53 |
+
mod_info[mod]['input_alphas'] = [alpha]
|
54 |
+
|
55 |
+
if 'target_alphas' in dataset_config:
|
56 |
+
target_alphas = dataset_config['target_alphas'].split('-')
|
57 |
+
if len(target_alphas) == 1:
|
58 |
+
target_alphas = [float(target_alphas[0])] * len(out_domains)
|
59 |
+
else:
|
60 |
+
target_alphas = [float(alpha) for alpha in target_alphas]
|
61 |
+
for mod, alpha in zip(out_domains, target_alphas):
|
62 |
+
mod_info[mod]["target_alphas"] = [alpha]
|
63 |
+
|
64 |
+
sampling_weights = None
|
65 |
+
else:
|
66 |
+
print(f"Loading alphas config from: {dataset_config['alphas_config']}")
|
67 |
+
with open(dataset_config['alphas_config'], "r") as f:
|
68 |
+
alphas_config = yaml.safe_load(f)
|
69 |
+
|
70 |
+
if 'sampling_weights' in alphas_config:
|
71 |
+
sampling_weights = alphas_config['sampling_weights']
|
72 |
+
alphas_config = alphas_config['alphas_mixture']
|
73 |
+
else:
|
74 |
+
sampling_weights = None
|
75 |
+
|
76 |
+
for mod in mod_info:
|
77 |
+
mod_info[mod]["input_alphas"] = alphas_config[mod]["input_alphas"]
|
78 |
+
mod_info[mod]["target_alphas"] = alphas_config[mod]["target_alphas"]
|
79 |
+
if modality_info[mod]['type'] in ['seq', 'seq_emb', 'seq_token']:
|
80 |
+
mod_info[mod]['keep'] = alphas_config[mod]['keep']
|
81 |
+
|
82 |
+
return mod_info, sampling_weights
|
83 |
+
|
84 |
+
def get_train_dataloader(dataset_config, modality_info, sampling_weights, text_tokenizer, input_size,
|
85 |
+
num_input_tokens, num_target_tokens, min_input_tokens, min_target_tokens,
|
86 |
+
num_tasks, num_workers, dataset_batch_size=None, epoch_size=None):
|
87 |
+
|
88 |
+
in_domains = sorted(list(dataset_config['in_domains'].split('-')))
|
89 |
+
out_domains = sorted(list(dataset_config['out_domains'].split('-')))
|
90 |
+
all_domains = sorted(list(set(in_domains) | set(out_domains)))
|
91 |
+
|
92 |
+
modality_transforms = MODALITY_TRANSFORMS
|
93 |
+
if 'caption' in modality_transforms:
|
94 |
+
modality_transforms['caption'] = CaptionTransform(
|
95 |
+
aligned_captions=dataset_config.get('aligned_captions', True)
|
96 |
+
)
|
97 |
+
|
98 |
+
if dataset_config['type'] == 'multimodal':
|
99 |
+
|
100 |
+
is_pretokenized = any([modality_info[mod].get('pretokenized', False) for mod in modality_info])
|
101 |
+
if is_pretokenized:
|
102 |
+
# Multi-modal training data augmentation (uses pre-tokenized data augmentation)
|
103 |
+
image_augmenter = PreTokenizedImageAugmenter(
|
104 |
+
target_size=input_size,
|
105 |
+
no_aug=(not dataset_config.get('tok_train_aug', True)),
|
106 |
+
main_domain=dataset_config['main_augment_domain']
|
107 |
+
)
|
108 |
+
else:
|
109 |
+
image_augmenter = RandomCropImageAugmenter(
|
110 |
+
target_size=input_size,
|
111 |
+
hflip=dataset_config.get('hflip'),
|
112 |
+
crop_scale=tuple(dataset_config.get('crop_scale')),
|
113 |
+
crop_ratio=tuple(dataset_config.get('crop_ratio')),
|
114 |
+
)
|
115 |
+
|
116 |
+
# Input and target token ranges
|
117 |
+
num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
|
118 |
+
num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
|
119 |
+
min_input_tokens = dataset_config.get('min_input_tokens', min_input_tokens)
|
120 |
+
min_target_tokens = dataset_config.get('min_target_tokens', min_target_tokens)
|
121 |
+
min_input_tokens = num_input_tokens if min_input_tokens is None else min_input_tokens
|
122 |
+
min_target_tokens = num_target_tokens if min_target_tokens is None else min_target_tokens
|
123 |
+
|
124 |
+
|
125 |
+
if dataset_config['use_wds']:
|
126 |
+
# Using webdataset
|
127 |
+
loader = build_wds_fm_pretraining_dataloader(
|
128 |
+
data_path=dataset_config['data_path'], all_domains=all_domains,
|
129 |
+
modality_info=modality_info, modality_transforms=modality_transforms,
|
130 |
+
image_augmenter=image_augmenter, text_tokenizer=text_tokenizer,
|
131 |
+
input_tokens_range=(min_input_tokens, num_input_tokens),
|
132 |
+
target_tokens_range=(min_target_tokens, num_target_tokens),
|
133 |
+
num_gpus=num_tasks, num_workers=num_workers,
|
134 |
+
batch_size=dataset_batch_size, epoch_size=epoch_size,
|
135 |
+
modality_name_map=dataset_config.get('modality_name_map', None),
|
136 |
+
shuffle_buffer_load=dataset_config.get('wds_shuffle_buffer_tar', 1_000),
|
137 |
+
shuffle_buffer_repeat=dataset_config.get('wds_shuffle_buffer_repeat', 1_000),
|
138 |
+
n_repeats=dataset_config.get('wds_n_repeats', 1),
|
139 |
+
sampling_weights=sampling_weights,
|
140 |
+
)
|
141 |
+
else:
|
142 |
+
dataset_train = build_fm_pretraining_dataset(
|
143 |
+
data_path=dataset_config['data_path'],
|
144 |
+
all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
|
145 |
+
image_augmenter=image_augmenter, text_tokenizer=text_tokenizer,
|
146 |
+
input_tokens_range=(min_input_tokens, num_input_tokens),
|
147 |
+
target_tokens_range=(min_target_tokens, num_target_tokens)
|
148 |
+
)
|
149 |
+
sampler_train = torch.utils.data.DistributedSampler(
|
150 |
+
dataset_train, num_replicas=num_tasks, rank=utils.get_rank(), shuffle=True, drop_last=True,
|
151 |
+
)
|
152 |
+
# DataLoader has batch size 1 as it then gets collated through the Mixture dataloader
|
153 |
+
loader = torch.utils.data.DataLoader(
|
154 |
+
dataset_train, sampler=sampler_train,
|
155 |
+
batch_size=1, num_workers=0,
|
156 |
+
pin_memory=False, drop_last=True,
|
157 |
+
collate_fn=lambda x: x[0],
|
158 |
+
)
|
159 |
+
|
160 |
+
elif dataset_config['type'] == 'huggingface':
|
161 |
+
|
162 |
+
# Input and target token ranges
|
163 |
+
num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
|
164 |
+
num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
|
165 |
+
|
166 |
+
if dataset_config.get('use_wds', False):
|
167 |
+
raise NotImplementedError('Webdataset not yet implemented for huggingface datasets.')
|
168 |
+
else:
|
169 |
+
loader = build_huggingface_pretraining_dataloader(
|
170 |
+
data_path=dataset_config['data_path'], all_domains=all_domains,
|
171 |
+
modality_info=modality_info, modality_transforms=modality_transforms,
|
172 |
+
image_augmenter=EmptyAugmenter(), text_tokenizer=text_tokenizer,
|
173 |
+
input_tokens_range=(num_input_tokens, num_input_tokens),
|
174 |
+
target_tokens_range=(num_target_tokens, num_target_tokens),
|
175 |
+
num_gpus=num_tasks, num_workers=num_workers,
|
176 |
+
batch_size=dataset_batch_size, epoch_size=epoch_size,
|
177 |
+
split='train', streaming=True, rename_text_to_caption=True,
|
178 |
+
shuffle_buffer_load=dataset_config.get('shuffle_buffer_load', 1_000),
|
179 |
+
shuffle_seed=0,
|
180 |
+
)
|
181 |
+
else:
|
182 |
+
raise NotImplementedError(f'Dataset type {dataset_config["type"]} not implemented.')
|
183 |
+
|
184 |
+
return loader
|
185 |
+
|
186 |
+
|
187 |
+
def cfgs_get(key, val_config, dataset_name, train_configs, default=None):
|
188 |
+
""" Try to retrieve a key from the validation set config.
|
189 |
+
If it does not exist, default to retrieving it from the train set config
|
190 |
+
with the same dataset name.
|
191 |
+
"""
|
192 |
+
    return val_config.get(key, train_configs[dataset_name].get(key, default))


def get_val_dataloader(dataset_config, dataset_name, train_configs, modality_info, sampling_weights, text_tokenizer,
                       input_size, num_input_tokens, num_target_tokens, min_input_tokens, min_target_tokens,
                       fixed_eval, fixed_eval_input_tokens, fixed_eval_target_tokens,
                       dist_eval, num_tasks, num_workers, batch_size, pin_mem):

    in_domains = sorted(list(cfgs_get('in_domains', dataset_config, dataset_name, train_configs).split('-')))
    out_domains = sorted(list(cfgs_get('out_domains', dataset_config, dataset_name, train_configs).split('-')))
    all_domains = sorted(list(set(in_domains) | set(out_domains)))

    modality_transforms = MODALITY_TRANSFORMS
    if 'caption' in modality_transforms:
        modality_transforms['caption'] = CaptionTransform(
            aligned_captions=cfgs_get('aligned_captions', dataset_config, dataset_name, train_configs, True)
        )

    dataset_type = cfgs_get('type', dataset_config, dataset_name, train_configs)

    if dataset_type == 'multimodal':

        main_augment_domain = cfgs_get('main_augment_domain', dataset_config, dataset_name, train_configs)
        is_pretokenized = any([modality_info[mod].get('pretokenized', False) for mod in modality_info])
        if is_pretokenized:
            eval_image_augmenter = PreTokenizedImageAugmenter(
                target_size=input_size, no_aug=True, main_domain=main_augment_domain
            )
        else:
            eval_image_augmenter = CenterCropImageAugmenter(
                target_size=input_size, main_domain=main_augment_domain
            )

        if fixed_eval:
            input_tokens_range = (fixed_eval_input_tokens, fixed_eval_input_tokens)
            target_tokens_range = (fixed_eval_target_tokens, fixed_eval_target_tokens)
        else:
            # Input and target token ranges
            num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
            num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
            min_input_tokens = dataset_config.get('min_input_tokens', min_input_tokens)
            min_target_tokens = dataset_config.get('min_target_tokens', min_target_tokens)
            min_input_tokens = num_input_tokens if min_input_tokens is None else min_input_tokens
            min_target_tokens = num_target_tokens if min_target_tokens is None else min_target_tokens
            input_tokens_range = (min_input_tokens, num_input_tokens)
            target_tokens_range = (min_target_tokens, num_target_tokens)

        dataset_val = build_fm_pretraining_dataset(
            data_path=cfgs_get('data_path', dataset_config, dataset_name, train_configs),
            all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
            image_augmenter=eval_image_augmenter, text_tokenizer=text_tokenizer,
            input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range
        )

        print("Warning: Eval stats may vary slightly as the masking applied on images is random.")
        if dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=utils.get_rank(), shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        loader = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_mem,
            drop_last=False,
        )

    elif dataset_type == 'huggingface':

        if fixed_eval:
            input_tokens_range = (fixed_eval_input_tokens, fixed_eval_input_tokens)
            target_tokens_range = (fixed_eval_target_tokens, fixed_eval_target_tokens)
        else:
            # Input and target token ranges
            num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
            num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
            input_tokens_range = (num_input_tokens, num_input_tokens)
            target_tokens_range = (num_target_tokens, num_target_tokens)

        loader = build_huggingface_pretraining_dataloader(
            data_path=cfgs_get('data_path', dataset_config, dataset_name, train_configs),
            all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
            image_augmenter=EmptyAugmenter(), text_tokenizer=text_tokenizer,
            input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range,
            num_gpus=num_tasks, num_workers=num_workers,
            batch_size=batch_size, epoch_size=None,
            split='validation', streaming=True, rename_text_to_caption=True,
            shuffle_buffer_load=cfgs_get('shuffle_buffer_load', dataset_config, dataset_name, train_configs, 1_000),
            shuffle_seed=0,
        )

    else:
        raise NotImplementedError(f'Dataset type {dataset_type} not implemented.')

    return loader
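As a quick illustration of the config fallback used throughout get_val_dataloader (this sketch is not part of the diff; it assumes the return statement above belongs to the cfgs_get helper, and the dataset name, keys, and paths are invented):

# cfgs_get looks the key up in the validation config first, then falls back to the
# train config of the same dataset, and finally to the given default.
train_configs = {'cc12m': {'in_domains': 'caption-rgb@224', 'data_path': '/data/cc12m/train'}}
val_config = {'data_path': '/data/cc12m/val'}  # only overrides the path

cfgs_get('data_path', val_config, 'cc12m', train_configs)                    # '/data/cc12m/val'
cfgs_get('in_domains', val_config, 'cc12m', train_configs)                   # 'caption-rgb@224' (train fallback)
cfgs_get('shuffle_buffer_load', val_config, 'cc12m', train_configs, 1_000)   # 1000 (default)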
fourm/data/transfer_utils.py
ADDED
@@ -0,0 +1,53 @@
# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

def convert_samples_to_mod_dict(samples, input_mod, target_mod, num_input_tokens, num_target_tokens):
    """Converts a sample (e.g. a batch of RGB images) to a mod dict that can be passed directly to FourM.
    Assumes both the input modality and target modality are dense tasks.
    """

    B = samples.shape[0]
    device = samples.device

    if input_mod == target_mod:
        assert(num_input_tokens == num_target_tokens)
        mod_dict = {
            input_mod: {
                'tensor': samples,
                'input_mask': torch.zeros((B, num_input_tokens), dtype=torch.bool, device=device),
                'target_mask': torch.zeros((B, num_target_tokens), dtype=torch.bool, device=device),
                'decoder_attention_mask': torch.zeros((B, num_target_tokens), dtype=torch.int, device=device),
            },
        }
        mod_dict[input_mod]['decoder_attention_mask'][:, 0] = num_target_tokens

    else:
        mod_dict = {
            input_mod: {
                'tensor': samples,
                'input_mask': torch.zeros((B, num_input_tokens), dtype=torch.bool, device=samples.device),
                'target_mask': torch.ones((B, num_input_tokens), dtype=torch.bool, device=samples.device),
                'decoder_attention_mask': torch.zeros((B, num_input_tokens), dtype=torch.int, device=samples.device),
            },
            target_mod: {
                'tensor': torch.zeros((B, num_target_tokens), dtype=torch.long, device=samples.device),
                'input_mask': torch.ones((B, num_target_tokens), dtype=torch.bool, device=samples.device),
                'target_mask': torch.zeros((B, num_target_tokens), dtype=torch.bool, device=samples.device),
                'decoder_attention_mask': torch.ones((B, num_target_tokens), dtype=torch.int, device=samples.device),
            },
        }
        mod_dict[target_mod]['decoder_attention_mask'][:, 0] = num_target_tokens

    return mod_dict
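A rough usage sketch for convert_samples_to_mod_dict (not part of the diff; the modality names and token counts below are placeholders chosen for a 224x224 input with 16x16 patches):

import torch

samples = torch.randn(8, 3, 224, 224)  # hypothetical batch of 8 RGB images
mod_dict = convert_samples_to_mod_dict(
    samples, input_mod='rgb@224', target_mod='tok_depth@224',
    num_input_tokens=196, num_target_tokens=196,
)
# The input modality is fully visible to the encoder (its input_mask is all False),
# while the target modality is an empty int64 placeholder the decoder is asked to predict.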
fourm/data/unified_datasets.py
ADDED
@@ -0,0 +1,557 @@
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import copy
|
15 |
+
import io
|
16 |
+
import itertools
|
17 |
+
import os
|
18 |
+
import re
|
19 |
+
from functools import partial
|
20 |
+
from typing import Any, Callable, Dict, Iterable, List, Optional
|
21 |
+
|
22 |
+
import braceexpand
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
import webdataset as wds
|
26 |
+
from PIL import Image
|
27 |
+
from torch.utils.data import IterableDataset
|
28 |
+
from torch.utils.data._utils.collate import default_collate
|
29 |
+
from torchvision import transforms
|
30 |
+
from webdataset.filters import pipelinefilter, reraise_exception
|
31 |
+
from webdataset.handlers import warn_and_continue
|
32 |
+
|
33 |
+
try:
|
34 |
+
# Optionally load huggingface datasets
|
35 |
+
from datasets import load_dataset
|
36 |
+
from datasets.distributed import split_dataset_by_node
|
37 |
+
except ImportError:
|
38 |
+
print("Huggingface datasets not installed. Please install with `pip install datasets`.")
|
39 |
+
|
40 |
+
from fourm.data.masking import TransferMasking, UnifiedMasking
|
41 |
+
from fourm.data.modality_transforms import (CropSettingsTransform, IdentityTransform,
|
42 |
+
MaskTransform, UnifiedDataTransform,
|
43 |
+
get_transform_key)
|
44 |
+
from fourm.data.multimodal_dataset_folder import MultiModalDatasetFolder
|
45 |
+
from fourm.utils.dist import get_rank, get_world_size
|
46 |
+
|
47 |
+
|
48 |
+
def build_fm_pretraining_dataset(
|
49 |
+
data_path, all_domains, modality_info, modality_transforms,
|
50 |
+
image_augmenter, text_tokenizer,
|
51 |
+
input_tokens_range, target_tokens_range,
|
52 |
+
sampling_weights=None):
|
53 |
+
"""Builds the FourM pre-training dataset based on the given arguments.
|
54 |
+
This function should mainly used for smaller datasets (e.g. validation sets),
|
55 |
+
while large training sets should be loaded with build_wds_fm_pretraining_dataloader in webdataset format.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
data_path: Path to the dataset.
|
59 |
+
all_domains: List of all modalities to be used.
|
60 |
+
modality_info: Dictionary containing information about the modalities.
|
61 |
+
modality_transforms: Dictionary containing the transforms for each modality.
|
62 |
+
image_augmenter: Image augmenter.
|
63 |
+
text_tokenizer: Text tokenizer (for sequence modalities).
|
64 |
+
input_tokens_range: Range of the input token budget.
|
65 |
+
target_tokens_range: Range of the target token budget.
|
66 |
+
sampling_weights: Sampling weights for the mixture of Dirichlet distributions.
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
FourM pre-training dataset as a PyTorch Dataset.
|
70 |
+
"""
|
71 |
+
|
72 |
+
transform = transforms.Compose([
|
73 |
+
UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=image_augmenter),
|
74 |
+
UnifiedMasking(modality_info=modality_info, text_tokenizer=text_tokenizer,
|
75 |
+
input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range,
|
76 |
+
sampling_weights=sampling_weights),
|
77 |
+
])
|
78 |
+
|
79 |
+
# Remove vq domains that require a tokenizer
|
80 |
+
modalities_without_vq = [mod for mod in all_domains if not modality_info[mod].get("requires_tokenizer", False)]
|
81 |
+
# If we are using a pre-tokenized modality, we default to pre-computed crop settings
|
82 |
+
if any([modality_info[domain].get("pretokenized", False) for domain in all_domains]):
|
83 |
+
modalities_without_vq.append("crop_settings")
|
84 |
+
modality_transforms = copy.deepcopy(modality_transforms)
|
85 |
+
modality_transforms["crop_settings"] = CropSettingsTransform()
|
86 |
+
|
87 |
+
modality_paths = {mod: modality_info[mod]['path'] for mod in modality_info if modality_info[mod].get('path', None) is not None}
|
88 |
+
|
89 |
+
return MultiModalDatasetFolder(root=data_path, modalities=modalities_without_vq, modality_paths=modality_paths,
|
90 |
+
modality_transforms=modality_transforms, transform=transform)
|
91 |
+
|
92 |
+
|
93 |
+
def build_fm_transfer_dataset(
|
94 |
+
data_path, modality_info, transform, modality_transforms, all_domains,
|
95 |
+
load_mask_valid: bool = False, max_samples: Optional[int] = None,
|
96 |
+
pre_shuffle: bool = False, cache: bool = False):
|
97 |
+
"""Builds the FourM transfer dataset based on the given arguments.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
data_path: Path to the dataset.
|
101 |
+
modality_info: Dictionary containing information about the modalities.
|
102 |
+
transform: Transform to be applied to the dataset.
|
103 |
+
modality_transforms: Dictionary containing the transforms for each modality.
|
104 |
+
all_domains: List of all modalities to be used.
|
105 |
+
load_mask_valid: Whether to load the mask_valid "modality".
|
106 |
+
max_samples: Maximum number of samples to load.
|
107 |
+
pre_shuffle: Whether to shuffle the dataset before loading.
|
108 |
+
cache: Whether to cache the dataset in memory.
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
FourM transfer dataset as a PyTorch Dataset.
|
112 |
+
"""
|
113 |
+
|
114 |
+
# Remove vq domains that require a tokenizer
|
115 |
+
modalities_without_vq = [mod for mod in all_domains if not modality_info[mod].get("requires_tokenizer", False)]
|
116 |
+
# If we are using a pre-tokenized modality, we default to pre-computed crop settings
|
117 |
+
if any([modality_info[domain].get("pretokenized", False) for domain in all_domains]):
|
118 |
+
modalities_without_vq.append("crop_settings")
|
119 |
+
modality_transforms = copy.deepcopy(modality_transforms)
|
120 |
+
modality_transforms["crop_settings"] = CropSettingsTransform()
|
121 |
+
|
122 |
+
if load_mask_valid:
|
123 |
+
modalities_without_vq.append("mask_valid")
|
124 |
+
modality_transforms = copy.deepcopy(modality_transforms)
|
125 |
+
modality_transforms["mask_valid"] = MaskTransform()
|
126 |
+
|
127 |
+
modality_paths = {mod: modality_info[mod]['path'] for mod in modality_info if modality_info[mod].get('path', None) is not None}
|
128 |
+
|
129 |
+
return MultiModalDatasetFolder(root=data_path, modalities=modalities_without_vq, modality_paths=modality_paths,
|
130 |
+
modality_transforms=modality_transforms, transform=transform, max_samples=max_samples,
|
131 |
+
pre_shuffle=pre_shuffle, cache=cache)
|
132 |
+
|
133 |
+
|
134 |
+
### Webdatasets (wds) functions
|
135 |
+
|
136 |
+
def _keyless_map(data, f, handler=reraise_exception):
|
137 |
+
"""Map samples without adding __key__."""
|
138 |
+
for sample in data:
|
139 |
+
try:
|
140 |
+
result = f(sample)
|
141 |
+
except Exception as exn:
|
142 |
+
if handler(exn):
|
143 |
+
continue
|
144 |
+
else:
|
145 |
+
break
|
146 |
+
if result is None:
|
147 |
+
continue
|
148 |
+
yield result
|
149 |
+
|
150 |
+
map = pipelinefilter(_keyless_map)
|
151 |
+
|
152 |
+
def check_dots(s):
|
153 |
+
if '.gz' in s:
|
154 |
+
return s.count('.') == 2
|
155 |
+
return s.count('.') == 1
|
156 |
+
|
157 |
+
def remove_ext_with_gz(s):
|
158 |
+
if s.endswith('.gz'):
|
159 |
+
s = s.replace(".gz", "")
|
160 |
+
return os.path.splitext(s)[0]
|
161 |
+
|
162 |
+
def wds_decoder(key, value):
|
163 |
+
if key == "png" or key.endswith(".png"):
|
164 |
+
img = Image.open(io.BytesIO(value))
|
165 |
+
return img
|
166 |
+
elif key == "jpg" or key.endswith(".jpg"):
|
167 |
+
img = Image.open(io.BytesIO(value))
|
168 |
+
return img
|
169 |
+
elif key == "jpeg" or key.endswith(".jpeg"):
|
170 |
+
img = Image.open(io.BytesIO(value))
|
171 |
+
return img
|
172 |
+
elif key == 'npy' or key.endswith("npy"):
|
173 |
+
content = np.load(io.BytesIO(value), allow_pickle=True)
|
174 |
+
# try:
|
175 |
+
# content = np.load(io.BytesIO(value))
|
176 |
+
# except:
|
177 |
+
# content = np.load(io.BytesIO(value), allow_pickle=True)
|
178 |
+
return content
|
179 |
+
elif key == "jpx" or key.endswith('.jpx'):
|
180 |
+
img = Image.open(io.BytesIO(value))
|
181 |
+
return img
|
182 |
+
elif 'output' in key:
|
183 |
+
return int(value)
|
184 |
+
else:
|
185 |
+
# If not an image, use the basic handlers (.txt, .json, .pickle, .npz, ...)
|
186 |
+
# See https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py
|
187 |
+
return None
|
188 |
+
|
189 |
+
def repeat_fn(src, n_repeats=5):
|
190 |
+
"""
|
191 |
+
Repeat each sample n_repeats times.
|
192 |
+
E.g. A B C ... repeated 3 times becomes A A A B B B C C C ...
|
193 |
+
Depending on the downstream application, a shuffle should be added after this.
|
194 |
+
"""
|
195 |
+
for sample in src:
|
196 |
+
for _ in range(n_repeats):
|
197 |
+
yield sample
|
198 |
+
|
199 |
+
def remove_extensions(sample):
|
200 |
+
"""
|
201 |
+
In webdatasets, we identify the type of a given modality by adding an extension
|
202 |
+
in the form f"{modality_name}.{modality_extension}", e.g. "rgb.jpg" or "caption.json".
|
203 |
+
This function removes them and returns a dictionary of {f"{modality_name}": modality}.
|
204 |
+
"""
|
205 |
+
return {remove_ext_with_gz(k): v for k, v in sample.items()}
|
206 |
+
|
207 |
+
def filter_metadata(sample, metadata=['__key__', '__url__', 'file_name', 'class_name', 'class_idx']):
|
208 |
+
""" Filters out non-modality entries specified in metadata when loading tar files with webdatasets. """
|
209 |
+
return {k: v for k, v in sample.items() if k not in metadata}
|
210 |
+
|
211 |
+
def apply_modality_transforms(sample, modality_transforms):
|
212 |
+
""" Applies a dictionary of modality-specific transforms to a dictionary of modalities. """
|
213 |
+
return {k: (modality_transforms[get_transform_key(k)](v) if k in modality_transforms else v) for k, v in sample.items() }
|
214 |
+
|
215 |
+
def tok_to_int64(sample):
|
216 |
+
"""
|
217 |
+
Pre-computed tokens are saved as int16, but we need them as int64 instead.
|
218 |
+
"""
|
219 |
+
return {k: (v.astype('int64') if 'tok_' in k else v) for k, v in sample.items()}
|
220 |
+
|
221 |
+
def rename_modalities(sample, modality_paths):
|
222 |
+
"""
|
223 |
+
Renames modalities to their corresponding names in modality_paths.
|
224 |
+
"""
|
225 |
+
return {out_path: sample[loaded_path] for out_path, loaded_path in modality_paths.items()}
|
226 |
+
|
227 |
+
def extract_modality_names(s):
|
228 |
+
# Regular expression pattern to match anything enclosed in '{' and '}', and comma separated
|
229 |
+
pattern = r'\{([^}]*)\}'
|
230 |
+
match = re.search(pattern, s)
|
231 |
+
return match.group(1).split(',') if match else []
|
232 |
+
|
233 |
+
def identity(sample):
|
234 |
+
""" Identity function that does nothing. """
|
235 |
+
return sample
|
236 |
+
|
237 |
+
def multi_tarfile_samples(src_iter: Iterable[Dict[str, Any]],
|
238 |
+
modality_name_map: Dict[str, str] = None,
|
239 |
+
handler: Callable[[Exception], bool] = warn_and_continue):
|
240 |
+
"""Webdataset does not support splitting up shards by modality, so we need to do this manually.
|
241 |
+
Usually, we would need to save all modalities in the same tar file, e.g. shard_root_train/{00000..12345}.tar,
|
242 |
+
where each shard contains 1000 samples and each sample contains all modalities.
|
243 |
+
This is not flexible when adding new modalities, so we instead save each modality in a separate tar file,
|
244 |
+
e.g. shard_root_train_rgb/{00000..12345}.tar, shard_root_train_caption/{00000..12345}.tar, etc., where each shard contains
|
245 |
+
again 1000 samples, but each sample contains only one modality. All samples in all shards have to be aligned.
|
246 |
+
|
247 |
+
This function takes an iterator over shard URLs, where we use brace expansion to specify multiple tar files per modality.
|
248 |
+
E.g. shard_root_train_[rgb,caption]/00123.tar will be expanded to shard_root_train_rgb/00123.tar and shard_root_train_caption/00123.tar,
|
249 |
+
and the samples from these two tar files will be combined into a single sample.
|
250 |
+
|
251 |
+
Args:
|
252 |
+
src_iter: Iterator over shards that *already brace expanded the shard numbers*,
|
253 |
+
e.g. {'url': 'shard_root_train_[rgb,caption]/00000.tar'}, {'url': 'shard_root_train_[rgb,caption]/00001.tar'}, ...
|
254 |
+
This function will also work when no square braces for multiple modalities are used, e.g. {'url': 'shard_root_train/00000.tar'}, ...
|
255 |
+
It can be a drop-in replacement for wds.tarfile_samples.
|
256 |
+
modality_name_map: Optional dictionary specifying a mapping from modality folder names to arbitrary other names.
|
257 |
+
handler: Function that handles exceptions. If it returns True, the shard is skipped. If it returns False, the function exits.
|
258 |
+
|
259 |
+
Yields:
|
260 |
+
Dictionary of aligned samples from all modalities.
|
261 |
+
"""
|
262 |
+
for src in src_iter:
|
263 |
+
|
264 |
+
# Multi tar file URLs use brace expansion with square braces
|
265 |
+
multi_tar_urls = src['url'].translate(str.maketrans('[]', '{}'))
|
266 |
+
modality_names = extract_modality_names(multi_tar_urls)
|
267 |
+
if len(modality_names) == 0:
|
268 |
+
# Case where multi-modal braceexpand is not used, e.g. shard_dir/shard00000.tar
|
269 |
+
modality_names = [None]
|
270 |
+
multi_tar_urls = [multi_tar_urls]
|
271 |
+
elif len(modality_names) == 1:
|
272 |
+
# Brace expand doesn't work with a single entry, e.g. shard_dir/[foo]/shard00000.tar
|
273 |
+
multi_tar_urls = [multi_tar_urls.replace("{", "").replace("}", "")]
|
274 |
+
else:
|
275 |
+
# Remaining cases where multiple modalities are specified, e.g. shard_dir/[foo,bar]/shard00000.tar
|
276 |
+
multi_tar_urls = list(braceexpand.braceexpand(multi_tar_urls))
|
277 |
+
|
278 |
+
# Create tar iterators for shards of all modalities
|
279 |
+
tar_iters = [wds.tarfile_samples([{'url': tar_url}]) for tar_url in multi_tar_urls]
|
280 |
+
|
281 |
+
try:
|
282 |
+
# Loop over these iterators in parallel and combine the tar files from different modalities
|
283 |
+
for multi_tar_files in zip(*tar_iters):
|
284 |
+
|
285 |
+
merged_dict = {}
|
286 |
+
merged_dict['__key__'] = multi_tar_files[0]['__key__']
|
287 |
+
merged_dict['__url__'] = src['url']
|
288 |
+
|
289 |
+
for modality_name, modality_dict in zip(modality_names, multi_tar_files):
|
290 |
+
_key = modality_dict.pop('__key__')
|
291 |
+
_url = modality_dict.pop('__url__')
|
292 |
+
|
293 |
+
if _key != merged_dict['__key__']:
|
294 |
+
raise ValueError(f"Divergence detected! Trying to merge keys {_key} of {modality_name} and {merged_dict['__key__']} of merged_dict with modalities {merged_dict.keys()}.")
|
295 |
+
|
296 |
+
tar_is_multimodal = len(modality_dict) > 1
|
297 |
+
for k, v in modality_dict.items():
|
298 |
+
if tar_is_multimodal or check_dots(k) or modality_name is None:
|
299 |
+
# We don't change the keys in the following cases:
|
300 |
+
# 1. The shard contains multiple modalities. Then they *have* to follow the idx.modality_id.ext convention
|
301 |
+
# 2. If any key contains a dot, this means it already has the idx.modality_id.ext format (idx. is already removed at this stage)
|
302 |
+
# 3. If the modality name is None, no modality folder was specified (see beginning of function)
|
303 |
+
merged_dict[k] = v
|
304 |
+
else:
|
305 |
+
mapped_name = modality_name if modality_name_map is None else modality_name_map.get(modality_name, modality_name)
|
306 |
+
merged_dict[f'{mapped_name}.{k}'] = v
|
307 |
+
|
308 |
+
yield merged_dict
|
309 |
+
|
310 |
+
except Exception as e:
|
311 |
+
print(e)
|
312 |
+
print(f"Exception occurred while processing {src['url']}.")
|
313 |
+
if handler(e):
|
314 |
+
print('Skipping shard...')
|
315 |
+
continue
|
316 |
+
else:
|
317 |
+
break
|
318 |
+
|
319 |
+
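# --- Illustrative aside, not part of this file: how the square-bracket shard URLs consumed by
# --- multi_tarfile_samples above map onto brace expansion. The paths are made up.
import braceexpand
url = 'shard_root_train_[rgb,caption]/00123.tar'
expanded = list(braceexpand.braceexpand(url.translate(str.maketrans('[]', '{}'))))
# expanded == ['shard_root_train_rgb/00123.tar', 'shard_root_train_caption/00123.tar'];
# samples from these per-modality tar files are then zipped back into one multi-modal sample.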
def build_wds_fm_pretraining_dataloader(
|
320 |
+
data_path, all_domains, modality_info, modality_transforms, image_augmenter,
|
321 |
+
text_tokenizer, input_tokens_range, target_tokens_range,
|
322 |
+
num_gpus, num_workers, batch_size, epoch_size, sampling_weights=None, modality_name_map=None,
|
323 |
+
shuffle_buffer_load=1000, shuffle_buffer_repeat=5000, n_repeats=5):
|
324 |
+
"""Builds the WebDataset FourM pre-training dataloader based on the given arguments.
|
325 |
+
|
326 |
+
Args:
|
327 |
+
data_path: Path to the dataset.
|
328 |
+
all_domains: List of all modalities to be used.
|
329 |
+
modality_info: Dictionary containing information about the modalities.
|
330 |
+
modality_transforms: Dictionary containing the transforms for each modality.
|
331 |
+
image_augmenter: Image augmenter.
|
332 |
+
text_tokenizer: Text tokenizer (for sequence modalities).
|
333 |
+
input_tokens_range: Range of the input token budget.
|
334 |
+
target_tokens_range: Range of the target token budget.
|
335 |
+
num_gpus: Number of GPUs.
|
336 |
+
num_workers: Number of workers.
|
337 |
+
batch_size: Batch size.
|
338 |
+
epoch_size: Number of samples per "epoch". (Here, epoch refers to an interrupted training loop without evaluation or checkpointing).
|
339 |
+
sampling_weights: Sampling weights for the mixture of Dirichlet distributions.
|
340 |
+
modality_name_map: Optional dictionary specifying a mapping from modality folder names to arbitrary other names.
|
341 |
+
shuffle_buffer_load: Shuffle buffer size when loading samples from tar files (first shuffle).
|
342 |
+
shuffle_buffer_repeat: Shuffle buffer size after repeating samples (second shuffle).
|
343 |
+
n_repeats: Number of times to repeat each sample.
|
344 |
+
|
345 |
+
Returns:
|
346 |
+
FourM pre-training dataloader as a WebDataset DataLoader.
|
347 |
+
"""
|
348 |
+
|
349 |
+
modality_paths = {mod: modality_info[mod].get('path', None) or mod for mod in modality_info}
|
350 |
+
|
351 |
+
# Remove vq domains that require a tokenizer
|
352 |
+
modalities_without_vq = [mod for mod in all_domains if not modality_info[mod].get("requires_tokenizer", False)]
|
353 |
+
# If we are using a pre-tokenized modality, we default to pre-computed crop settings
|
354 |
+
if any([modality_info[domain].get("pretokenized", False) for domain in all_domains]):
|
355 |
+
modalities_without_vq.append("crop_settings")
|
356 |
+
modality_transforms = copy.deepcopy(modality_transforms)
|
357 |
+
modality_transforms["crop_settings"] = CropSettingsTransform()
|
358 |
+
modality_paths["crop_settings"] = "crop_settings"
|
359 |
+
|
360 |
+
# Webdatasets always adds __key__ to the dictionary, so we add a transform that does nothing with it
|
361 |
+
modality_transforms["__key__"] = IdentityTransform()
|
362 |
+
|
363 |
+
transform = transforms.Compose([
|
364 |
+
UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=image_augmenter),
|
365 |
+
UnifiedMasking(modality_info=modality_info, text_tokenizer=text_tokenizer,
|
366 |
+
input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range,
|
367 |
+
sampling_weights=sampling_weights)
|
368 |
+
])
|
369 |
+
|
370 |
+
datapipe = wds.DataPipeline(
|
371 |
+
# Infinitely sample shards from the shard list with replacement. Each worker is seeded independently.
|
372 |
+
wds.ResampledShards(data_path),
|
373 |
+
partial(multi_tarfile_samples, modality_name_map=modality_name_map), # Extract individual samples from single or multi-modal tar files
|
374 |
+
wds.shuffle(shuffle_buffer_load), # Shuffle with a buffer of given size
|
375 |
+
wds.decode(wds_decoder), # Decode from bytes to PIL images, numpy arrays, etc.
|
376 |
+
wds.filters.compose(partial(repeat_fn, n_repeats=n_repeats)), # Repeats each sample n times -> A A A B B B C C C ...
|
377 |
+
wds.shuffle(shuffle_buffer_repeat), # Shuffle again with a buffer of given size
|
378 |
+
wds.map(remove_extensions), # Remove "file extensions" from dictionary keys
|
379 |
+
map(filter_metadata), # Remove non-task keys
|
380 |
+
map(tok_to_int64), # Convert pre-computed tokens to int64
|
381 |
+
map(partial(rename_modalities, modality_paths=modality_paths)), # Rename modalities to their corresponding names in modality_paths
|
382 |
+
map(transform), # Apply data augmentation and masking
|
383 |
+
wds.batched(batch_size, collation_fn=default_collate, partial=False)
|
384 |
+
if batch_size is not None else map(identity), # Batching
|
385 |
+
)
|
386 |
+
|
387 |
+
if epoch_size is not None:
|
388 |
+
batch_size_iter = batch_size if batch_size is not None else 1
|
389 |
+
datapipe = datapipe.with_epoch(epoch_size // (num_gpus * num_workers * batch_size_iter)) # Pre-define iterator length
|
390 |
+
|
391 |
+
if batch_size is not None:
|
392 |
+
# Perform multi-threaded dataloading
|
393 |
+
return wds.WebLoader(datapipe, num_workers=num_workers, batch_size=None)
|
394 |
+
else:
|
395 |
+
return datapipe
|
396 |
+
|
397 |
+
|
398 |
+
def build_wds_divae_dataloader(
|
399 |
+
data_path, modality_info, modality_transforms, image_augmenter,
|
400 |
+
num_gpus, num_workers, batch_size, epoch_size, shuffle_buffer_load=1000,
|
401 |
+
shuffle_buffer_repeat=5000, n_repeats=1):
|
402 |
+
|
403 |
+
modality_paths = {mod: modality_info[mod].get('path', None) or mod for mod in modality_info}
|
404 |
+
|
405 |
+
# Webdatasets always adds __key__ to the dictionary, so we add a transform that does nothing with it
|
406 |
+
modality_transforms["__key__"] = IdentityTransform()
|
407 |
+
|
408 |
+
transform = UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=image_augmenter)
|
409 |
+
|
410 |
+
datapipe = wds.DataPipeline(
|
411 |
+
# Infinitely sample shards from the shard list with replacement. Each worker is seeded independently.
|
412 |
+
wds.ResampledShards(data_path),
|
413 |
+
multi_tarfile_samples, # Extract individual samples from single or multi-modal tar files
|
414 |
+
wds.shuffle(shuffle_buffer_load), # Shuffle with a buffer of given size
|
415 |
+
wds.decode(wds_decoder), # Decode from bytes to PIL images, numpy arrays, etc.
|
416 |
+
wds.filters.compose(partial(repeat_fn, n_repeats=n_repeats)), # Repeats each sample n times -> A A A B B B C C C ...
|
417 |
+
wds.shuffle(shuffle_buffer_repeat), # Shuffle again with a buffer of given size
|
418 |
+
map(remove_extensions), # Remove "file extensions" from dictionary keys
|
419 |
+
map(filter_metadata), # Remove non-task keys
|
420 |
+
map(tok_to_int64), # Convert pre-computed tokens to int64
|
421 |
+
map(partial(rename_modalities, modality_paths=modality_paths)), # Rename modalities to their corresponding names in modality_paths
|
422 |
+
map(transform), # Apply data augmentation and masking
|
423 |
+
wds.batched(batch_size, collation_fn=default_collate, partial=False)
|
424 |
+
if batch_size is not None else map(identity), # Batching
|
425 |
+
)
|
426 |
+
|
427 |
+
if epoch_size is not None:
|
428 |
+
batch_size_iter = batch_size if batch_size is not None else 1
|
429 |
+
datapipe = datapipe.with_epoch(epoch_size // (num_gpus * num_workers * batch_size_iter)) # Pre-define iterator length
|
430 |
+
|
431 |
+
if batch_size is not None:
|
432 |
+
# Perform multi-threaded dataloading
|
433 |
+
return wds.WebLoader(datapipe, num_workers=num_workers, batch_size=None)
|
434 |
+
else:
|
435 |
+
return datapipe
|
436 |
+
|
437 |
+
|
438 |
+
### Huggingface datasets functions
|
439 |
+
|
440 |
+
def text_to_caption(sample):
|
441 |
+
""" Rename "text" to "caption". """
|
442 |
+
return {'caption': sample['text']}
|
443 |
+
|
444 |
+
|
445 |
+
def build_huggingface_pretraining_dataloader(
|
446 |
+
data_path, all_domains, modality_info, modality_transforms, image_augmenter,
|
447 |
+
text_tokenizer, input_tokens_range, target_tokens_range,
|
448 |
+
num_gpus, num_workers, batch_size, epoch_size, split,
|
449 |
+
streaming=True, rename_text_to_caption=True, shuffle_buffer_load=10_000, shuffle_seed=0):
|
450 |
+
|
451 |
+
# Load huggingface dataset and split samples across workers. Shuffle samples in each worker
|
452 |
+
dataset = load_dataset(data_path, split=split, streaming=streaming)
|
453 |
+
dataset = split_dataset_by_node(dataset, rank=get_rank(), world_size=get_world_size())
|
454 |
+
dataset = dataset.shuffle(seed=shuffle_seed, buffer_size=shuffle_buffer_load)
|
455 |
+
|
456 |
+
modality_info = {mod: modality_info[mod] for mod in modality_info if mod in all_domains}
|
457 |
+
|
458 |
+
transform = transforms.Compose([
|
459 |
+
UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=image_augmenter),
|
460 |
+
UnifiedMasking(modality_info=modality_info, text_tokenizer=text_tokenizer,
|
461 |
+
input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range)
|
462 |
+
])
|
463 |
+
|
464 |
+
datapipe = wds.DataPipeline(
|
465 |
+
dataset,
|
466 |
+
map(text_to_caption) if rename_text_to_caption else map(identity), # Rename "text" to "caption"
|
467 |
+
map(filter_metadata), # Remove non-task keys
|
468 |
+
map(transform), # Apply data augmentation and masking
|
469 |
+
wds.batched(batch_size, collation_fn=default_collate, partial=False)
|
470 |
+
if batch_size is not None else map(identity), # Batching
|
471 |
+
)
|
472 |
+
|
473 |
+
datapipe.n_shards = dataset.n_shards
|
474 |
+
num_workers = min(num_workers, dataset.n_shards)
|
475 |
+
|
476 |
+
if epoch_size is not None:
|
477 |
+
batch_size_iter = batch_size if batch_size is not None else 1
|
478 |
+
datapipe = datapipe.with_epoch(epoch_size // (num_gpus * num_workers * batch_size_iter)) # Pre-define iterator length
|
479 |
+
|
480 |
+
if batch_size is not None:
|
481 |
+
# Perform multi-threaded dataloading
|
482 |
+
return wds.WebLoader(datapipe, num_workers=num_workers, batch_size=None)
|
483 |
+
else:
|
484 |
+
return datapipe
|
485 |
+
|
486 |
+
|
487 |
+
### Multi-dataset loading utils
|
488 |
+
def make_empty_mod_dict(modality_info):
|
489 |
+
empty_mod_dicts = {}
|
490 |
+
|
491 |
+
for mod_name, mod_info in modality_info.items():
|
492 |
+
empty_mod = {}
|
493 |
+
|
494 |
+
# Tensor
|
495 |
+
if 'num_channels' in mod_info and 'input_size' in mod_info:
|
496 |
+
# Handle image-like modalities
|
497 |
+
max_tokens = mod_info['max_tokens']
|
498 |
+
empty_mod['tensor'] = torch.zeros((mod_info['num_channels'], mod_info['input_size'], mod_info['input_size']), dtype=torch.float32)
|
499 |
+
elif mod_name == 't5_caption':
|
500 |
+
# Handle T5 embedding
|
501 |
+
max_tokens = mod_info['max_tokens']
|
502 |
+
orig_emb_dim = mod_info['encoder_embedding']().orig_emb_dim
|
503 |
+
empty_mod['tensor'] = torch.zeros((max_tokens, orig_emb_dim), dtype=torch.float32)
|
504 |
+
elif mod_info['type'] in ['seq', 'seq_emb', 'seq_token']:
|
505 |
+
# Handle all other discrete sequence modalities
|
506 |
+
max_tokens = (mod_info['max_tokens'] + 1) * 2
|
507 |
+
empty_mod['tensor'] = torch.zeros((max_tokens), dtype=torch.int32)
|
508 |
+
else:
|
509 |
+
max_tokens = mod_info['max_tokens']
|
510 |
+
empty_mod['tensor'] = torch.zeros((max_tokens), dtype=torch.int32)
|
511 |
+
|
512 |
+
# Input and target masks
|
513 |
+
empty_mod['input_mask'] = torch.ones((max_tokens), dtype=torch.bool)
|
514 |
+
empty_mod['target_mask'] = torch.ones((max_tokens), dtype=torch.bool)
|
515 |
+
|
516 |
+
# Decoder attention mask
|
517 |
+
empty_mod['decoder_attention_mask'] = torch.zeros((max_tokens), dtype=torch.int32)
|
518 |
+
|
519 |
+
empty_mod_dicts[mod_name] = empty_mod
|
520 |
+
|
521 |
+
return empty_mod_dicts
|
522 |
+
|
523 |
+
|
524 |
+
class MixtureDataset(IterableDataset):
|
525 |
+
def __init__(self, data_iters, weights, modality_info):
|
526 |
+
self.orig_data_iters = data_iters
|
527 |
+
self.data_iters = [iter(data_iter) for data_iter in data_iters] # Create initial iterators
|
528 |
+
self.sampling_probs = np.array(weights) / sum(weights)
|
529 |
+
self.modality_info = modality_info
|
530 |
+
|
531 |
+
def reset_iterator(self, idx):
|
532 |
+
""" Reset the iterator when exhausted. """
|
533 |
+
self.data_iters[idx] = iter(self.orig_data_iters[idx])
|
534 |
+
|
535 |
+
def __iter__(self):
|
536 |
+
while True:
|
537 |
+
dataset_idx = np.random.choice(len(self.sampling_probs), p=self.sampling_probs)
|
538 |
+
try:
|
539 |
+
data = next(self.data_iters[dataset_idx])
|
540 |
+
except StopIteration: # If the iterator is exhausted
|
541 |
+
self.reset_iterator(dataset_idx) # Reset it
|
542 |
+
data = next(self.data_iters[dataset_idx])
|
543 |
+
|
544 |
+
mod_dict = make_empty_mod_dict(self.modality_info)
|
545 |
+
mod_dict.update(data)
|
546 |
+
yield mod_dict
|
547 |
+
|
548 |
+
|
549 |
+
def build_mixture_dataloader(data_iters, weights, modality_info, batch_size, num_workers, epoch_size, num_gpus):
|
550 |
+
mixture_pipe = wds.DataPipeline(
|
551 |
+
MixtureDataset(data_iters, weights, modality_info),
|
552 |
+
wds.batched(batch_size, collation_fn=default_collate, partial=False),
|
553 |
+
).with_epoch(epoch_size // (num_gpus * num_workers * batch_size)) # Pre-define iterator length
|
554 |
+
|
555 |
+
mixture_loader = wds.WebLoader(mixture_pipe, num_workers=num_workers, batch_size=None)
|
556 |
+
|
557 |
+
return mixture_loader
|
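A minimal sketch of how the mixture utilities above are meant to be combined (not part of the diff; the iterator names, weights, and sizes are placeholders):

# Two unbatched data pipelines (e.g. built with batch_size=None by the loaders above),
# sampled with 80/20 probability and batched into a single loader.
mixture_loader = build_mixture_dataloader(
    data_iters=[wds_iter_a, wds_iter_b], weights=[0.8, 0.2],
    modality_info=modality_info, batch_size=64,
    num_workers=8, epoch_size=100_000, num_gpus=1,
)
# Each yielded batch is a dict over modalities; modalities missing from a sampled dataset
# are filled in with the empty placeholders created by make_empty_mod_dict.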
fourm/demo_4M_sampler.py
ADDED
@@ -0,0 +1,540 @@
1 |
+
from typing import Optional, List
|
2 |
+
import os
|
3 |
+
import math
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import requests
|
9 |
+
from tokenizers import Tokenizer
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
from torchvision.transforms.functional import center_crop
|
12 |
+
|
13 |
+
from fourm.models.fm import FM
|
14 |
+
from fourm.vq.vqvae import VQVAE, DiVAE
|
15 |
+
from fourm.models.generate import GenerationSampler, build_chained_generation_schedules, init_empty_target_modality, init_full_input_modality, custom_text
|
16 |
+
from fourm.utils.plotting_utils import decode_dict
|
17 |
+
from fourm.data.modality_info import MODALITY_INFO
|
18 |
+
from fourm.data.modality_transforms import RGBTransform
|
19 |
+
from fourm.utils import load_safetensors
|
20 |
+
from fourm.utils.plotting_utils import decode_dict, visualize_bboxes, plot_text_in_square, text_to_pil_image
|
21 |
+
|
22 |
+
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later.
|
23 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
24 |
+
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
|
25 |
+
torch.backends.cudnn.allow_tf32 = True
|
26 |
+
|
27 |
+
|
28 |
+
# Default chained generation order
|
29 |
+
DEFAULT_ORDER = [
|
30 |
+
'tok_clip@224', 'tok_dinov2@224', 'tok_imagebind@224', 'tok_depth@224', 'tok_normal@224',
|
31 |
+
'tok_semseg@224', 'tok_canny_edge@224', 'tok_sam_edge@224', 'tok_rgb@224',
|
32 |
+
'caption', 'det', 'human_poses', 'sam_instance', 'color_palette', 'metadata',
|
33 |
+
]
|
34 |
+
|
35 |
+
# Default super-resolution chained generation order
|
36 |
+
DEFAULT_ORDER_SR = [
|
37 |
+
'tok_clip@448', 'tok_depth@448', 'tok_normal@448',
|
38 |
+
'tok_semseg@448', 'tok_rgb@448',
|
39 |
+
]
|
40 |
+
|
41 |
+
# Default generation parameters for the case where the input contains RGB
|
42 |
+
DEFAULTS_RGB2X = {
|
43 |
+
'tok_clip@224/tok_depth@224/tok_normal@224/tok_semseg@224/tok_canny_edge@224/tok_sam_edge@224': {
|
44 |
+
'tokens_per_target': 196, 'autoregression_scheme': 'roar', 'decoding_steps': 1,
|
45 |
+
'token_decoding_schedule': 'linear', 'temp': 0.01, 'temp_schedule': 'constant',
|
46 |
+
'cfg_scale': 2.0, 'cfg_schedule': 'constant',
|
47 |
+
},
|
48 |
+
'tok_dinov2@224/tok_imagebind@224': {
|
49 |
+
'tokens_per_target': 256, 'autoregression_scheme': 'roar', 'decoding_steps': 1,
|
50 |
+
'token_decoding_schedule': 'linear', 'temp': 0.01, 'temp_schedule': 'constant',
|
51 |
+
'cfg_scale': 2.0, 'cfg_schedule': 'constant',
|
52 |
+
},
|
53 |
+
'caption/det': {
|
54 |
+
'tokens_per_target': 256, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
|
55 |
+
'token_decoding_schedule': None, 'temp': 0.3, 'temp_schedule': 'constant',
|
56 |
+
'cfg_scale': 1.0, 'cfg_schedule': 'constant',
|
57 |
+
},
|
58 |
+
'human_poses': {
|
59 |
+
'tokens_per_target': 275, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
|
60 |
+
'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
|
61 |
+
'cfg_scale': 1.0, 'cfg_schedule': 'constant',
|
62 |
+
},
|
63 |
+
'sam_instance': {
|
64 |
+
'tokens_per_target': 256, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
|
65 |
+
'token_decoding_schedule': None, 'temp': 0.01, 'temp_schedule': 'constant',
|
66 |
+
'cfg_scale': 1.0, 'cfg_schedule': 'constant',
|
67 |
+
},
|
68 |
+
'color_palette': {
|
69 |
+
'tokens_per_target': 23, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
|
70 |
+
'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
|
71 |
+
'cfg_scale': 1.0, 'cfg_schedule': 'constant',
|
72 |
+
},
|
73 |
+
'metadata': {
|
74 |
+
'tokens_per_target': 40, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
|
75 |
+
'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
|
76 |
+
'cfg_scale': 1.0, 'cfg_schedule': 'constant',
|
77 |
+
},
|
78 |
+
}
|
79 |
+
|
80 |
+
# Default generation parameters for the case where the target is RGB
|
81 |
+
DEFAULTS_X2RGB = {
|
82 |
+
'tok_clip@224': {
|
83 |
+
'tokens_per_target': 196, 'autoregression_scheme': 'roar', 'decoding_steps': 50,
|
84 |
+
'token_decoding_schedule': 'linear', 'temp': 5.0, 'temp_schedule': 'onex:0.5:0.5',
|
85 |
+
'cfg_scale': 3.0, 'cfg_schedule': 'constant',
|
86 |
+
},
|
87 |
+
'tok_dinov2@224/tok_imagebind@224': {
|
88 |
+
'tokens_per_target': 256, 'autoregression_scheme': 'roar', 'decoding_steps': 8,
|
89 |
+
'token_decoding_schedule': 'linear', 'temp': 0.01, 'temp_schedule': 'constant',
|
90 |
+
'cfg_scale': 2.0, 'cfg_schedule': 'constant',
|
91 |
+
},
|
92 |
+
'tok_depth@224/tok_normal@224/tok_semseg@224/tok_canny_edge@224/tok_sam_edge@224': {
|
93 |
+
'tokens_per_target': 196, 'autoregression_scheme': 'roar', 'decoding_steps': 8,
|
94 |
+
'token_decoding_schedule': 'linear', 'temp': 3.0, 'temp_schedule': 'onex:0.5:0.5',
|
95 |
+
'cfg_scale': 2.0, 'cfg_schedule': 'constant',
|
96 |
+
},
|
97 |
+
'tok_rgb@224': {
|
98 |
+
'tokens_per_target': 196, 'autoregression_scheme': 'roar', 'decoding_steps': 25,
|
99 |
+
'token_decoding_schedule': 'linear', 'temp': 3.0, 'temp_schedule': 'onex:0.5:0.5',
|
100 |
+
'cfg_scale': 2.0, 'cfg_schedule': 'constant',
|
101 |
+
},
|
102 |
+
'caption/det': {
|
103 |
+
'tokens_per_target': 256, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
|
104 |
+
'token_decoding_schedule': None, 'temp': 0.3, 'temp_schedule': 'constant',
|
105 |
+
'cfg_scale': 1.0, 'cfg_schedule': 'constant',
|
106 |
+
},
|
107 |
+
'human_poses': {
|
108 |
+
'tokens_per_target': 275, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
|
109 |
+
'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
|
110 |
+
'cfg_scale': 1.0, 'cfg_schedule': 'constant',
|
111 |
+
},
|
112 |
+
'sam_instance': {
|
113 |
+
'tokens_per_target': 256, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
|
114 |
+
'token_decoding_schedule': None, 'temp': 0.01, 'temp_schedule': 'constant',
|
115 |
+
'cfg_scale': 1.0, 'cfg_schedule': 'constant',
|
116 |
+
},
|
117 |
+
'color_palette': {
|
118 |
+
'tokens_per_target': 23, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
|
119 |
+
'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
|
120 |
+
'cfg_scale': 1.0, 'cfg_schedule': 'constant',
|
121 |
+
},
|
122 |
+
'metadata': {
|
123 |
+
'tokens_per_target': 40, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
|
124 |
+
'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
|
125 |
+
'cfg_scale': 1.0, 'cfg_schedule': 'constant',
|
126 |
+
},
|
127 |
+
}
|
128 |
+
|
129 |
+
# Default generation parameters for super-resolution
|
130 |
+
DEFAULTS_SR = {
|
131 |
+
'tok_clip@448/tok_depth@448/tok_normal@448/tok_semseg@448/tok_rgb@448': {
|
132 |
+
'tokens_per_target': 784, 'autoregression_scheme': 'maskgit', 'decoding_steps': 8,
|
133 |
+
'token_decoding_schedule': 'cosine', 'temp': 1.0, 'temp_schedule': 'constant',
|
134 |
+
'cfg_scale': 2.0, 'cfg_schedule': 'constant',
|
135 |
+
},
|
136 |
+
}
|
137 |
+
|
138 |
+
# Plotting names for each modality
|
139 |
+
MODALITY_PLOTTING_NAME_MAP = {
|
140 |
+
'caption': 'Caption',
|
141 |
+
'det': 'Bounding boxes',
|
142 |
+
'human_poses': 'Human poses',
|
143 |
+
'sam_instance': 'SAM instances (single pass)',
|
144 |
+
'color_palette': 'Color palette',
|
145 |
+
'metadata': 'Metadata',
|
146 |
+
'rgb@224': 'RGB (224x224)',
|
147 |
+
'rgb@448': 'RGB (448x448)',
|
148 |
+
'tok_rgb@224': 'RGB (tokenized, 224x224)',
|
149 |
+
'tok_rgb@448': 'RGB (tokenized, 448x448)',
|
150 |
+
'tok_clip@224': 'CLIP-B/16 (224x224)',
|
151 |
+
'tok_clip@448': 'CLIP-B/16 (448x448)',
|
152 |
+
'tok_depth@224': 'Depth (224x224)',
|
153 |
+
'tok_depth@448': 'Depth (448x448)',
|
154 |
+
'tok_normal@224': 'Normals (224x224)',
|
155 |
+
'tok_normal@448': 'Normals (448x448)',
|
156 |
+
'tok_semseg@224': 'Semantic segmentation (224x224)',
|
157 |
+
'tok_semseg@448': 'Semantic segmentation (448x448)',
|
158 |
+
'tok_canny_edge@224': 'Canny edges (224x224)',
|
159 |
+
'tok_sam_edge@224': 'SAM edges (224x224)',
|
160 |
+
'tok_dinov2@224': 'DINOv2-B/14 (224x224)',
|
161 |
+
'tok_imagebind@224': 'ImageBind-H/14 (224x224)',
|
162 |
+
}
|
163 |
+
|
164 |
+
# Optional fixed plotting order (by default, plotting order is determined by generation order)
|
165 |
+
MODALITY_PLOTTING_ORDER = [
|
166 |
+
'rgb@224', 'rgb@448', 'tok_rgb@224', 'tok_rgb@448',
|
167 |
+
'tok_depth@224', 'tok_depth@448', 'tok_normal@224', 'tok_normal@448',
|
168 |
+
'tok_semseg@224', 'tok_semseg@448', 'tok_canny_edge@224', 'tok_sam_edge@224',
|
169 |
+
'sam_instance', 'human_poses', 'det', 'caption', 'metadata', 'color_palette',
|
170 |
+
'tok_clip@224', 'tok_clip@448', 'tok_dinov2@224', 'tok_imagebind@224',
|
171 |
+
]
|
172 |
+
|
173 |
+
|
174 |
+
def get_value(defaults_dict, domain, key):
|
175 |
+
"""Look up a default value belonging to a given domain and key."""
|
176 |
+
for domains, defaults in defaults_dict.items():
|
177 |
+
if domain in domains:
|
178 |
+
return defaults[key]
|
179 |
+
|
180 |
+
def load_model(model_id, model_class):
|
181 |
+
"""Load a model from HuggingFace hub or a given .safetensors checkpoint path."""
|
182 |
+
if model_id.endswith('.safetensors'):
|
183 |
+
ckpt, config = load_safetensors(model_id)
|
184 |
+
model = model_class(config=config)
|
185 |
+
model.load_state_dict(ckpt)
|
186 |
+
else:
|
187 |
+
model = model_class.from_pretrained(model_id)
|
188 |
+
return model
|
189 |
+
|
190 |
+
def img_from_url(url: str):
|
191 |
+
rgb_transform = RGBTransform(imagenet_default_mean_and_std=True)
|
192 |
+
img_data = requests.get(url).content
|
193 |
+
with open('demo.png', 'wb') as handler:
|
194 |
+
handler.write(img_data)
|
195 |
+
img_pil = rgb_transform.load('./demo.png')
|
196 |
+
img_pil = rgb_transform.preprocess(img_pil)
|
197 |
+
img_pil = center_crop(img_pil, (min(img_pil.size), min(img_pil.size))).resize((224,224))
|
198 |
+
img = rgb_transform.postprocess(img_pil).unsqueeze(0)
|
199 |
+
return img
|
200 |
+
|
201 |
+
|
202 |
+
class Demo4MSampler(nn.Module):
|
203 |
+
"""Convenience wrapper for easy 4M loading and generation. Users can specify HuggingFace Hub
|
204 |
+
model URLs, or downloaded safetensors checkpoints paths, and the models will be automatically
|
205 |
+
loaded. The `forward` function can be used for RGB-2-all and {caption,det}-2-all generation.
|
206 |
+
This wrapper is only intended for quickly trying out 4M models. For more advanced usecases we
|
207 |
+
recommend looking at the generation notebooks in `./notebooks/`, and `./run_generation.py`.
|
208 |
+
|
209 |
+
Args:
|
210 |
+
fm: Hub or safetensors path of 4M base model
|
211 |
+
fm_sr: Hub or safetensors path of 4M super-resolution model
|
212 |
+
tok_rgb: Hub or safetensors path of RGB tokenizer
|
213 |
+
tok_depth: Hub or safetensors path of depth tokenizer
|
214 |
+
tok_normal: Hub or safetensors path of surface normal tokenizer
|
215 |
+
tok_edge: Hub or safetensors path of canny edge tokenizer (for SAM and RGB edges)
|
216 |
+
tok_semseg: Hub or safetensors path of COCO semantic segmentation tokenizer
|
217 |
+
tok_clip: Hub or safetensors path of CLIP-B/16 tokenizer
|
218 |
+
tok_dinov2: Hub or safetensors path of DINOv2-B/14 tokenizer
|
219 |
+
tok_imagebind: Hub or safetensors path of ImageBind-H/14 tokenizer
|
220 |
+
tok_sam_instance: Hub or safetensors path of SAM instance tokenizer
|
221 |
+
tok_human_poses: Hub or safetensors path of human poses tokenizer
|
222 |
+
tok_text: Path to text tokenizer JSON file
|
223 |
+
mods: Optional list of modalities to override default behavior of generating everything
|
224 |
+
mods_sr: Optional list of super-res modalities to override default behavior of generating everything
|
225 |
+
"""
|
226 |
+
def __init__(self,
|
227 |
+
fm: str = 'EPFL-VILAB/4M-21_XL_CC12M',
|
228 |
+
fm_sr: Optional[str] = 'EPFL-VILAB/4M-7-SR_L_CC12M',
|
229 |
+
tok_rgb: Optional[str] = 'EPFL-VILAB/4M_tokenizers_rgb_16k_224-448',
|
230 |
+
tok_depth: Optional[str] = 'EPFL-VILAB/4M_tokenizers_depth_8k_224-448',
|
231 |
+
tok_normal: Optional[str] = 'EPFL-VILAB/4M_tokenizers_normal_8k_224-448',
|
232 |
+
tok_edge: Optional[str] = 'EPFL-VILAB/4M_tokenizers_edge_8k_224-512',
|
233 |
+
tok_semseg: Optional[str] = 'EPFL-VILAB/4M_tokenizers_semseg_4k_224-448',
|
234 |
+
tok_clip: Optional[str] = 'EPFL-VILAB/4M_tokenizers_CLIP-B16_8k_224-448',
|
235 |
+
tok_dinov2: Optional[str] = 'EPFL-VILAB/4M_tokenizers_DINOv2-B14_8k_224-448',
|
236 |
+
tok_imagebind: Optional[str] = 'EPFL-VILAB/4M_tokenizers_ImageBind-H14_8k_224-448',
|
237 |
+
tok_sam_instance: Optional[str] = 'EPFL-VILAB/4M_tokenizers_sam-instance_1k_64',
|
238 |
+
tok_human_poses: Optional[str] = 'EPFL-VILAB/4M_tokenizers_human-poses_1k_8',
|
239 |
+
tok_text: str = './fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json',
|
240 |
+
mods: Optional[List[str]] = None,
|
241 |
+
mods_sr: Optional[List[str]] = None,
|
242 |
+
verbose: bool = True):
|
243 |
+
super().__init__()
|
244 |
+
|
245 |
+
self.verbose = verbose
|
246 |
+
if self.verbose:
|
247 |
+
print('Loading 4M models and tokenizers...', end='')
|
248 |
+
|
249 |
+
# Load 4M model and initialize sampler
|
250 |
+
fm = load_model(fm, FM)
|
251 |
+
self.sampler_fm = GenerationSampler(fm)
|
252 |
+
self.mods = mods or list(set(fm.encoder_modalities) | set(fm.decoder_modalities))
|
253 |
+
|
254 |
+
# Load optional 4M super-res model and initialize sampler
|
255 |
+
if fm_sr is not None:
|
256 |
+
fm_sr = load_model(fm_sr, FM)
|
257 |
+
self.sampler_fm_sr = GenerationSampler(fm_sr)
|
258 |
+
self.mods_sr = mods_sr or list(set(fm_sr.encoder_modalities) | set(fm_sr.decoder_modalities))
|
259 |
+
else:
|
260 |
+
self.sampler_fm_sr = None
|
261 |
+
|
262 |
+
# Load tokenizers
|
263 |
+
self.toks = {}
|
264 |
+
if ('tok_rgb@224' in self.mods or 'tok_rgb@448' in self.mods_sr) and tok_rgb is not None:
|
265 |
+
self.toks['tok_rgb'] = load_model(tok_rgb, DiVAE)
|
266 |
+
if ('tok_depth@224' in self.mods or 'tok_depth@448' in self.mods_sr) and tok_depth is not None:
|
267 |
+
self.toks['tok_depth'] = load_model(tok_depth, DiVAE)
|
268 |
+
if ('tok_normal@224' in self.mods or 'tok_normal@448' in self.mods_sr) and tok_normal is not None:
|
269 |
+
self.toks['tok_normal'] = load_model(tok_normal, DiVAE)
|
270 |
+
if ('tok_canny_edge@224' in self.mods or 'tok_sam_edge@224' in self.mods) and tok_edge is not None:
|
271 |
+
self.toks['tok_canny_edge'] = load_model(tok_edge, DiVAE)
|
272 |
+
self.toks['tok_sam_edge'] = self.toks['tok_canny_edge'] # Shared tokenizer
|
273 |
+
if ('tok_semseg@224' in self.mods or 'tok_semseg@448' in self.mods_sr) and tok_semseg is not None:
|
274 |
+
self.toks['tok_semseg'] = load_model(tok_semseg, VQVAE)
|
275 |
+
if ('tok_clip@224' in self.mods or 'tok_clip@448' in self.mods_sr) and tok_clip is not None:
|
276 |
+
self.toks['tok_clip'] = load_model(tok_clip, VQVAE)
|
277 |
+
if 'tok_dinov2@224' in self.mods and tok_dinov2 is not None:
|
278 |
+
self.toks['tok_dinov2'] = load_model(tok_dinov2, VQVAE)
|
279 |
+
if 'tok_imagebind@224' in self.mods and tok_imagebind is not None:
|
280 |
+
self.toks['tok_imagebind'] = load_model(tok_imagebind, VQVAE)
|
281 |
+
if 'sam_instance' in self.mods and tok_sam_instance is not None:
|
282 |
+
self.toks['sam_instance'] = load_model(tok_sam_instance, VQVAE)
|
283 |
+
if 'human_poses' in self.mods and tok_human_poses is not None:
|
284 |
+
self.toks['human_poses'] = load_model(tok_human_poses, VQVAE)
|
285 |
+
self.toks = nn.ModuleDict(self.toks)
|
286 |
+
self.tok_text = Tokenizer.from_file(tok_text)
|
287 |
+
|
288 |
+
if self.verbose:
|
289 |
+
print(' done!')
|
290 |
+
|
291 |
+
@property
|
292 |
+
def device(self):
|
293 |
+
return next(self.parameters()).device
|
294 |
+
|
295 |
+
def __setup_conds_and_targets(self, sample):
|
296 |
+
# Input and output modalities
|
297 |
+
cond_domains = [domain for domain in list(sample.keys()) if domain in self.mods]
|
298 |
+
target_domains = [domain for domain in DEFAULT_ORDER if (domain not in cond_domains and domain in self.mods)]
|
299 |
+
if 'rgb@224' in cond_domains:
|
300 |
+
# Do not generate tokenized RGB if pixel RGB is given as input
|
301 |
+
target_domains.remove('tok_rgb@224')
|
302 |
+
return cond_domains, target_domains
|
303 |
+
|
304 |
+
def __setup_sr_conds_and_targets(self, sample):
|
305 |
+
cond_domains_sr = [domain for domain in list(sample.keys()) if domain in self.mods_sr]
|
306 |
+
target_domains_sr = [domain for domain in DEFAULT_ORDER_SR if (domain.replace('448', '224') in cond_domains_sr and domain in self.mods_sr)]
|
307 |
+
return cond_domains_sr, target_domains_sr
|
308 |
+
|
309 |
+
def __setup_sample_and_schedule(self, sample, cond_domains, target_domains, cfg_grow_conditioning=True):
|
310 |
+
# 1 - Setup generation schedule
|
311 |
+
|
312 |
+
defaults = DEFAULTS_RGB2X if ('rgb@224' in cond_domains or 'tok_rgb@224' in cond_domains) else DEFAULTS_X2RGB
|
313 |
+
|
314 |
+
tokens_per_target = [get_value(defaults, domain, 'tokens_per_target') for domain in target_domains]
|
315 |
+
autoregression_schemes = [get_value(defaults, domain, 'autoregression_scheme') for domain in target_domains]
|
316 |
+
decoding_steps = [get_value(defaults, domain, 'decoding_steps') for domain in target_domains]
|
317 |
+
token_decoding_schedules = [get_value(defaults, domain, 'token_decoding_schedule') for domain in target_domains]
|
318 |
+
temps = [get_value(defaults, domain, 'temp') for domain in target_domains]
|
319 |
+
temp_schedules = [get_value(defaults, domain, 'temp_schedule') for domain in target_domains]
|
320 |
+
cfg_scales = [get_value(defaults, domain, 'cfg_scale') for domain in target_domains]
|
321 |
+
cfg_schedules = [get_value(defaults, domain, 'cfg_schedule') for domain in target_domains]
|
322 |
+
|
323 |
+
schedule = build_chained_generation_schedules(
|
324 |
+
cond_domains=cond_domains, target_domains=target_domains, tokens_per_target=tokens_per_target,
|
325 |
+
autoregression_schemes=autoregression_schemes, decoding_steps=decoding_steps,
|
326 |
+
token_decoding_schedules=token_decoding_schedules, temps=temps, temp_schedules=temp_schedules,
|
327 |
+
cfg_scales=cfg_scales, cfg_schedules=cfg_schedules, cfg_grow_conditioning=cfg_grow_conditioning,
|
328 |
+
)
|
329 |
+
|
330 |
+
# 2 - Setup sample
|
331 |
+
|
332 |
+
sample_dict = {}
|
333 |
+
|
334 |
+
# Handle special cases
|
335 |
+
if 'caption' in sample:
|
336 |
+
caption = sample.pop('caption')
|
337 |
+
sample_dict = custom_text(
|
338 |
+
sample_dict, input_text=caption, eos_token='[EOS]',
|
339 |
+
key='caption', device=self.device, text_tokenizer=self.tok_text
|
340 |
+
)
|
341 |
+
if 'det' in sample:
|
342 |
+
caption = sample.pop('det')
|
343 |
+
sample_dict = custom_text(
|
344 |
+
sample_dict, input_text=caption, eos_token='[EOS]',
|
345 |
+
key='det', device=self.device, text_tokenizer=self.tok_text
|
346 |
+
)
|
347 |
+
# Add remaining modalities
|
348 |
+
sample_dict.update({domain: {'tensor': tensor} for domain, tensor in sample.items()})
|
349 |
+
|
350 |
+
# Initialize these remaining input modalities (caption and det are already initialized by custom_text)
|
351 |
+
for cond_mod in sample.keys():
|
352 |
+
sample_dict = init_full_input_modality(sample_dict, MODALITY_INFO, cond_mod, self.device, eos_id=self.tok_text.token_to_id("[EOS]"))
|
353 |
+
|
354 |
+
# Initialize target modalities
|
355 |
+
for target_mod, ntoks in zip(target_domains, tokens_per_target):
|
356 |
+
sample_dict = init_empty_target_modality(sample_dict, MODALITY_INFO, target_mod, 1, ntoks, self.device)
|
357 |
+
|
358 |
+
return sample_dict, schedule
|
359 |
+
|
360 |
+
def __setup_sr_sample_and_schedule(self, out_dict, cond_domains_sr, target_domains_sr, cfg_grow_conditioning_sr=True):
|
361 |
+
# 1 - Setup generation schedule
|
362 |
+
|
363 |
+
tokens_per_target_sr = [get_value(DEFAULTS_SR, domain, 'tokens_per_target') for domain in target_domains_sr]
|
364 |
+
autoregression_schemes_sr = [get_value(DEFAULTS_SR, domain, 'autoregression_scheme') for domain in target_domains_sr]
|
365 |
+
decoding_steps_sr = [get_value(DEFAULTS_SR, domain, 'decoding_steps') for domain in target_domains_sr]
|
366 |
+
token_decoding_schedules_sr = [get_value(DEFAULTS_SR, domain, 'token_decoding_schedule') for domain in target_domains_sr]
|
367 |
+
temps_sr = [get_value(DEFAULTS_SR, domain, 'temp') for domain in target_domains_sr]
|
368 |
+
temp_schedules_sr = [get_value(DEFAULTS_SR, domain, 'temp_schedule') for domain in target_domains_sr]
|
369 |
+
cfg_scales_sr = [get_value(DEFAULTS_SR, domain, 'cfg_scale') for domain in target_domains_sr]
|
370 |
+
cfg_schedules_sr = [get_value(DEFAULTS_SR, domain, 'cfg_schedule') for domain in target_domains_sr]
|
371 |
+
|
372 |
+
schedule_sr = build_chained_generation_schedules(
|
373 |
+
cond_domains=cond_domains_sr, target_domains=target_domains_sr, tokens_per_target=tokens_per_target_sr,
|
374 |
+
autoregression_schemes=autoregression_schemes_sr, decoding_steps=decoding_steps_sr,
|
375 |
+
token_decoding_schedules=token_decoding_schedules_sr, temps=temps_sr, temp_schedules=temp_schedules_sr,
|
376 |
+
cfg_scales=cfg_scales_sr, cfg_schedules=cfg_schedules_sr, cfg_grow_conditioning=cfg_grow_conditioning_sr,
|
377 |
+
)
|
378 |
+
|
379 |
+
# 2 - Setup sample
|
380 |
+
|
381 |
+
sample_sr = out_dict
|
382 |
+
|
383 |
+
# Handle the case where the generated caption or bounding boxes are just [EOS]
|
384 |
+
if 'caption' in sample_sr and sample_sr['caption']['tensor'].shape[1] <= 1 and 'caption' in cond_domains_sr:
|
385 |
+
sample_sr = custom_text(
|
386 |
+
sample_sr, input_text='[S_1]', eos_token='[EOS]',
|
387 |
+
key='caption', device=self.device, text_tokenizer=self.tok_text
|
388 |
+
)
|
389 |
+
if 'det' in sample_sr and sample_sr['det']['tensor'].shape[1] <= 1 and 'det' in cond_domains_sr:
|
390 |
+
sample_sr = custom_text(
|
391 |
+
sample_sr, input_text='[S_1]', eos_token='[EOS]',
|
392 |
+
key='det', device=self.device, text_tokenizer=self.tok_text
|
393 |
+
)
|
394 |
+
|
395 |
+
# Initialize input modalities
|
396 |
+
for cond_mod in cond_domains_sr:
|
397 |
+
sample_sr = init_full_input_modality(sample_sr, MODALITY_INFO, cond_mod, self.device, eos_id=self.tok_text.token_to_id("[EOS]"))
|
398 |
+
|
399 |
+
# Initialize target modalities
|
400 |
+
for target_mod, ntoks in zip(target_domains_sr, tokens_per_target_sr):
|
401 |
+
sample_sr = init_empty_target_modality(sample_sr, MODALITY_INFO, target_mod, 1, ntoks, self.device)
|
402 |
+
|
403 |
+
return sample_sr, schedule_sr
|
404 |
+
|
405 |
+
def forward(self, sample, seed: Optional[int] = None, top_p: float = 0.8, top_k: float = 0.0, target_modalities: Optional[List[str]] = None, perform_sr: bool = True):
|
406 |
+
seed = seed or np.random.randint(np.iinfo(np.int64).max)
|
407 |
+
|
408 |
+
# Prepare the generation parameters and sample
|
409 |
+
cond_domains, target_domains = self.__setup_conds_and_targets(sample)
|
410 |
+
target_domains = target_modalities or target_domains
|
411 |
+
sample, generation_schedule = self.__setup_sample_and_schedule(sample, cond_domains, target_domains)
|
412 |
+
|
413 |
+
# Generation and decoding at the base resolution 224x224
|
414 |
+
if self.verbose:
|
415 |
+
print(f'Generating {cond_domains} -> {target_domains} ...')
|
416 |
+
out_dict = self.sampler_fm.generate(
|
417 |
+
sample, generation_schedule, text_tokenizer=self.tok_text,
|
418 |
+
verbose=self.verbose, seed=seed, top_p=top_p, top_k=top_k,
|
419 |
+
)
|
420 |
+
dec_dict = decode_dict(
|
421 |
+
out_dict, self.toks, self.tok_text, image_size=224,
|
422 |
+
patch_size=16, decoding_steps=50
|
423 |
+
)
|
424 |
+
|
425 |
+
# Optional upsampling to 448x448
|
426 |
+
if self.sampler_fm_sr is not None and perform_sr:
|
427 |
+
cond_domains_sr, target_domains_sr = self.__setup_sr_conds_and_targets(out_dict)
|
428 |
+
sample_sr, generation_schedule_sr = self.__setup_sr_sample_and_schedule(out_dict, cond_domains_sr, target_domains_sr)
|
429 |
+
|
430 |
+
if self.verbose:
|
431 |
+
print(f'Super-resolving {target_domains_sr} ...')
|
432 |
+
out_dict_sr = self.sampler_fm_sr.generate(
|
433 |
+
sample_sr, generation_schedule_sr, text_tokenizer=self.tok_text,
|
434 |
+
verbose=self.verbose, seed=seed+1, top_p=top_p, top_k=top_k,
|
435 |
+
)
|
436 |
+
dec_dict = decode_dict(
|
437 |
+
out_dict_sr, self.toks, self.tok_text, image_size=448,
|
438 |
+
patch_size=16, decoding_steps=50
|
439 |
+
)
|
440 |
+
|
441 |
+
# Remove padding tokens
|
442 |
+
if 'caption' in dec_dict:
|
443 |
+
dec_dict['caption'][0] = dec_dict['caption'][0].replace('[PAD]', '').strip()
|
444 |
+
if 'det' in dec_dict:
|
445 |
+
dec_dict['det'][0] = dec_dict['det'][0].replace('[PAD]', '').strip()
|
446 |
+
|
447 |
+
return dec_dict
|
448 |
+
|
449 |
+
def plot_modalities(self, mod_dict, ncols_max=5, figscale=4.0, save_path=None, use_fixed_plotting_order=False):
|
450 |
+
nmods = len(mod_dict)
|
451 |
+
ncols = min(nmods, ncols_max)
|
452 |
+
nrows = math.ceil(nmods / ncols)
|
453 |
+
|
454 |
+
fig, ax = plt.subplots(
|
455 |
+
nrows=nrows, ncols=ncols,
|
456 |
+
figsize=(ncols*figscale, nrows*figscale),
|
457 |
+
facecolor=(1, 1, 1)
|
458 |
+
)
|
459 |
+
|
460 |
+
if use_fixed_plotting_order:
|
461 |
+
mod_dict = {
|
462 |
+
k: mod_dict[k] for k in MODALITY_PLOTTING_ORDER
|
463 |
+
if k in mod_dict
|
464 |
+
}
|
465 |
+
|
466 |
+
for i, (mod_name, mod) in enumerate(mod_dict.items()):
|
467 |
+
if nrows == 1:
|
468 |
+
ax_i = ax[i]
|
469 |
+
else:
|
470 |
+
row, col = i // ncols, i % ncols
|
471 |
+
ax_i = ax[row,col]
|
472 |
+
|
473 |
+
if mod_name == 'det':
|
474 |
+
# Attempt to get the first available value from mod_dict according to the priority
|
475 |
+
keys_in_order = ['rgb@448', 'rgb@224', 'tok_rgb@448', 'tok_rgb@224']
|
476 |
+
rgb_background = next((mod_dict[key] for key in keys_in_order if key in mod_dict), np.ones((224, 224, 3)))
|
477 |
+
rgb_background = (255 * rgb_background).astype(np.uint8)
|
478 |
+
ax_i.imshow(visualize_bboxes(rgb_background, mod[0],).astype(np.uint8))
|
479 |
+
elif mod_name == 'caption':
|
480 |
+
plot_text_in_square(ax_i, mod[0], wrap_width=16, fontsize=14)
|
481 |
+
elif mod_name == 'metadata':
|
482 |
+
metadata_pred = ',\n'.join([f'{k}: {v:.2f}' if isinstance(v, float) else f'{k}: {v}' for k, v in mod.items()])
|
483 |
+
plot_text_in_square(ax_i, metadata_pred, wrap_width=36, fontsize=13)
|
484 |
+
else:
|
485 |
+
ax_i.imshow(mod)
|
486 |
+
|
487 |
+
ax_i.set_title(MODALITY_PLOTTING_NAME_MAP.get(mod_name, mod_name), fontsize=18)
|
488 |
+
|
489 |
+
for i, axis in enumerate(ax.flatten()):
|
490 |
+
axis.set_xticks([])
|
491 |
+
axis.set_yticks([])
|
492 |
+
if i >= len(mod_dict):
|
493 |
+
axis.spines['top'].set_visible(False)
|
494 |
+
axis.spines['right'].set_visible(False)
|
495 |
+
axis.spines['bottom'].set_visible(False)
|
496 |
+
axis.spines['left'].set_visible(False)
|
497 |
+
|
498 |
+
plt.tight_layout()
|
499 |
+
if save_path is not None:
|
500 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
501 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=300)
|
502 |
+
plt.close()
|
503 |
+
else:
|
504 |
+
plt.show()
|
505 |
+
|
506 |
+
def modalities_to_pil(self, mod_dict, use_fixed_plotting_order=False, resize=None):
|
507 |
+
if use_fixed_plotting_order:
|
508 |
+
mod_dict = {
|
509 |
+
k: mod_dict[k] for k in MODALITY_PLOTTING_ORDER
|
510 |
+
if k in mod_dict
|
511 |
+
}
|
512 |
+
|
513 |
+
plotted_modalities = []
|
514 |
+
|
515 |
+
for i, (mod_name, mod) in enumerate(mod_dict.items()):
|
516 |
+
if mod_name == 'det':
|
517 |
+
# Attempt to get the first available value from mod_dict according to the priority
|
518 |
+
keys_in_order = ['rgb@448', 'rgb@224', 'tok_rgb@448', 'tok_rgb@224']
|
519 |
+
rgb_background = next((mod_dict[key] for key in keys_in_order if key in mod_dict), np.ones((224, 224, 3)))
|
520 |
+
rgb_background = (255 * rgb_background).astype(np.uint8)
|
521 |
+
img_pil = Image.fromarray(visualize_bboxes(rgb_background, mod[0],).astype(np.uint8))
|
522 |
+
elif mod_name == 'caption':
|
523 |
+
img_pil = text_to_pil_image(mod[0][:512], wrap_width=40, fontsize=14)
|
524 |
+
elif mod_name == 'metadata':
|
525 |
+
metadata_pred = ',\n'.join([f'{k}: {v:.2f}' if isinstance(v, float) else f'{k}: {v}' for k, v in mod.items()])
|
526 |
+
img_pil = text_to_pil_image(metadata_pred, wrap_width=36, fontsize=13)
|
527 |
+
else:
|
528 |
+
img_pil = Image.fromarray((255*mod).astype(np.uint8))
|
529 |
+
|
530 |
+
if resize is not None:
|
531 |
+
if mod_name in ['tok_clip@224', 'tok_dinov2@224', 'tok_imagebind@224', 'tok_clip@448']:
|
532 |
+
resample_mode = Image.Resampling.NEAREST
|
533 |
+
else:
|
534 |
+
resample_mode = Image.Resampling.BILINEAR
|
535 |
+
img_pil = img_pil.resize((resize, resize), resample=resample_mode)
|
536 |
+
|
537 |
+
plot_name = MODALITY_PLOTTING_NAME_MAP.get(mod_name, mod_name)
|
538 |
+
plotted_modalities.append((img_pil, plot_name))
|
539 |
+
|
540 |
+
return plotted_modalities
|
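A minimal usage sketch for the sampler implemented above. It assumes an instance `sampler` of the class defined earlier in demo_4M_sampler.py has already been constructed; the dummy RGB tensor and its preprocessing are illustrative assumptions, not taken from this diff.

import torch

# `sampler` is assumed to be an already-constructed instance of the demo sampler
# class whose methods (forward, plot_modalities, modalities_to_pil) are added above.
rgb = torch.randn(1, 3, 224, 224)   # placeholder for a preprocessed 224x224 RGB image
sample = {'rgb@224': rgb}           # condition on pixel RGB only

# forward() picks the remaining modalities as targets, generates them at 224x224,
# optionally super-resolves to 448x448, and returns a decoded modality dict.
dec_dict = sampler(sample, seed=42, top_p=0.8, top_k=0.0, perform_sr=True)

# Visualize the generated modalities, or convert them to PIL images.
sampler.plot_modalities(dec_dict, ncols_max=5)
pil_images = sampler.modalities_to_pil(dec_dict, resize=224)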
fourm/models/__init__.py
ADDED
File without changes
|
fourm/models/decoder_embeddings.py
ADDED
@@ -0,0 +1,268 @@
|
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from functools import partial
|
15 |
+
from typing import Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
from einops import repeat
|
20 |
+
|
21 |
+
from .fm_utils import build_1d_sincos_posemb, build_2d_sincos_posemb, pair
|
22 |
+
|
23 |
+
|
24 |
+
class SequenceDecoderEmbedding(nn.Module):
|
25 |
+
"""Embedding module for sequence inputs, like captions or a sequence of objects.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
vocab_size: Vocabulary size
|
29 |
+
max_length: Maximum number of tokens in the sequence
|
30 |
+
dim_tokens: Dimension of output tokens. Can be set using init method.
|
31 |
+
sincos_pos_emb: Set to True (default) to use fixed 1D sin-cos positional embeddings
|
32 |
+
padding_idx: Padding index for word embedding
|
33 |
+
share_embedding: Set to True to share input and output embedding weights
|
34 |
+
"""
|
35 |
+
def __init__(self,
|
36 |
+
vocab_size: int,
|
37 |
+
max_length: int,
|
38 |
+
dim_tokens: Optional[int] = None,
|
39 |
+
sincos_pos_emb: bool = True,
|
40 |
+
max_sincos_pos_emb: int = 512,
|
41 |
+
padding_idx: int = 0,
|
42 |
+
share_embedding: bool = True,
|
43 |
+
**kwargs):
|
44 |
+
super().__init__()
|
45 |
+
self.vocab_size = vocab_size
|
46 |
+
self.max_length = max_length
|
47 |
+
self.dim_tokens = dim_tokens
|
48 |
+
self.sincos_pos_emb = sincos_pos_emb
|
49 |
+
self.padding_idx = padding_idx
|
50 |
+
self.max_sincos_pos_emb = max_sincos_pos_emb
|
51 |
+
self.share_embedding = share_embedding
|
52 |
+
|
53 |
+
if self.dim_tokens is not None:
|
54 |
+
self.init(dim_tokens=dim_tokens)
|
55 |
+
|
56 |
+
def init(self, dim_tokens: int = 768, init_std=0.02):
|
57 |
+
"""
|
58 |
+
Initialize parts of embedding module that are dependent on dimension of tokens.
|
59 |
+
Should be called when setting up FourM.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
dim_tokens: Dimension of tokens
|
63 |
+
init_std: Standard deviation of init
|
64 |
+
"""
|
65 |
+
self.dim_tokens = dim_tokens
|
66 |
+
|
67 |
+
# Task embedding identifying from which task a given token comes from
|
68 |
+
# Fixed-size positional embeddings. Can be interpolated to different input sizes
|
69 |
+
|
70 |
+
if self.sincos_pos_emb:
|
71 |
+
if self.max_length > self.max_sincos_pos_emb:
|
72 |
+
raise ValueError(f"Max length ({self.max_length}) is greater than the number of posembs ({self.max_sincos_pos_emb}")
|
73 |
+
# Get all posembs, then truncate up to max length
|
74 |
+
pos_emb = build_1d_sincos_posemb(max_len=self.max_sincos_pos_emb, embed_dim=self.dim_tokens)[:self.max_length]
|
75 |
+
self.register_buffer("pos_emb", pos_emb)
|
76 |
+
else:
|
77 |
+
self.pos_emb = nn.Parameter(torch.zeros(1, self.max_length, self.dim_tokens))
|
78 |
+
nn.init.normal_(self.pos_emb, std=init_std)
|
79 |
+
|
80 |
+
self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
|
81 |
+
nn.init.normal_(self.mod_emb, std=init_std)
|
82 |
+
|
83 |
+
# Token embedding
|
84 |
+
self.token_emb = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.dim_tokens, padding_idx=self.padding_idx)
|
85 |
+
|
86 |
+
# Output projection layer
|
87 |
+
self.to_logits = nn.Linear(self.dim_tokens, self.vocab_size, bias=False)
|
88 |
+
|
89 |
+
if self.share_embedding:
|
90 |
+
# Share input and output embedding weights
|
91 |
+
self.to_logits.weight = self.token_emb.weight
|
92 |
+
|
93 |
+
|
94 |
+
@torch.jit.ignore
|
95 |
+
def no_weight_decay(self):
|
96 |
+
return set()
|
97 |
+
|
98 |
+
def forward_embed(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
99 |
+
"""
|
100 |
+
Forward pass through embedding module, transforming sequence of ids to sequence of embeddings.
|
101 |
+
Creates corresponding modality and positional embeddings and adds them to the dict.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
d (Dict[str, torch.Tensor]): Modality dict, with at least the following keys:
|
105 |
+
- 'tensor' (torch.Tensor): Token sequence for each batch. Shape (B, L) where B is the batch size and L is the sequence length.
|
106 |
+
- 'target_mask' (torch.Tensor): Mask for valid tokens in the target sequence (set to 0 for valid tokens and 1 otherwise). Shape (B, L).
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
Dict[str, torch.Tensor]: Modality dict with added keys:
|
110 |
+
- 'x' (torch.Tensor): Embedded token sequence. Shape (B, L, D) where D is the embedding dimension.
|
111 |
+
- 'emb' (torch.Tensor): Sum of positional and modality embeddings for the target sequence. Shape (B, L, D).
|
112 |
+
- 'ids' (torch.Tensor): Original token sequence from input dict. Shape (B, L).
|
113 |
+
"""
|
114 |
+
ids = d['tensor']
|
115 |
+
B = ids.shape[0]
|
116 |
+
assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first'
|
117 |
+
|
118 |
+
# Map to embedding
|
119 |
+
x = self.token_emb(ids)
|
120 |
+
|
121 |
+
expanded_pos_emb = repeat(self.pos_emb, "() n d -> b n d", b=B)
|
122 |
+
|
123 |
+
# Target pos encoding
|
124 |
+
target_mask = d['target_mask']
|
125 |
+
target_pos_id = (~target_mask).int().cumsum(dim=1) - 1
|
126 |
+
target_pos_id[target_mask] = 0
|
127 |
+
# Sometimes the target sequence exceeds the max length; it will be truncated in the decoder
|
128 |
+
target_pos_id[target_pos_id >= self.max_length] = 0
|
129 |
+
target_pos_emb = torch.gather(expanded_pos_emb, dim=1, index=repeat(target_pos_id, "b n -> b n d", d=expanded_pos_emb.shape[2]))
|
130 |
+
target_pos_emb[target_mask] = 0
|
131 |
+
|
132 |
+
x_emb = target_pos_emb + self.mod_emb
|
133 |
+
|
134 |
+
|
135 |
+
d['x'] = x
|
136 |
+
d['emb'] = x_emb
|
137 |
+
d['ids'] = d['tensor']
|
138 |
+
|
139 |
+
return d
|
140 |
+
|
141 |
+
def forward_logits(self, x: torch.Tensor) -> torch.Tensor:
|
142 |
+
"""
|
143 |
+
Forward pass through output projection layer, transforming sequence of embeddings to logits.
|
144 |
+
|
145 |
+
Args:
|
146 |
+
x (torch.Tensor): Output tokens from the decoder. Shape (B, M, D)
|
147 |
+
|
148 |
+
Returns:
|
149 |
+
torch.Tensor: Logits for each token in the sequence. Shape (B, M, V)
|
150 |
+
"""
|
151 |
+
logits = self.to_logits(x)
|
152 |
+
return logits
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
class ImageTokenDecoderEmbedding(nn.Module):
|
157 |
+
"""Embedding module for tokenized spatial inputs.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
vocab_size: Vocabulary size
|
161 |
+
patch_size: Int or tuple of the patch size over the full image size.
|
162 |
+
dim_tokens: Dimension of output tokens. Can be set using init method.
|
163 |
+
sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings
|
164 |
+
image_size: Default image size. Used to initialize size of positional embeddings.
|
165 |
+
share_embedding: Set to True to share input and output embedding weights
|
166 |
+
"""
|
167 |
+
def __init__(self,
|
168 |
+
vocab_size: int,
|
169 |
+
patch_size: Union[int, Tuple[int,int]] = 16,
|
170 |
+
dim_tokens: Optional[int] = None,
|
171 |
+
sincos_pos_emb: bool = True,
|
172 |
+
image_size: Union[int, Tuple[int]] = 224,
|
173 |
+
share_embedding: bool = True,
|
174 |
+
**kwargs):
|
175 |
+
super().__init__()
|
176 |
+
self.vocab_size = vocab_size
|
177 |
+
self.patch_size = pair(patch_size)
|
178 |
+
self.dim_tokens = dim_tokens
|
179 |
+
self.sincos_pos_emb = sincos_pos_emb
|
180 |
+
self.image_size = pair(image_size)
|
181 |
+
self.num_patches = (self.image_size[0] // self.patch_size[0]) * (self.image_size[1] // self.patch_size[1])
|
182 |
+
self.share_embedding = share_embedding
|
183 |
+
|
184 |
+
if self.dim_tokens is not None:
|
185 |
+
self.init(dim_tokens=dim_tokens)
|
186 |
+
|
187 |
+
def init(self, dim_tokens: int = 768, init_std=0.02):
|
188 |
+
"""
|
189 |
+
Initialize parts of module that are dependent on dimension of tokens.
|
190 |
+
Should be called when setting up FourM.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
dim_tokens: Dimension of tokens
|
194 |
+
init_std: Standard deviation of init
|
195 |
+
"""
|
196 |
+
self.dim_tokens = dim_tokens
|
197 |
+
|
198 |
+
# Task embedding identifying from which task a given token comes from
|
199 |
+
# Fixed-size positional embeddings. Can be interpolated to different input sizes
|
200 |
+
h_posemb = self.image_size[0] // self.patch_size[0]
|
201 |
+
w_posemb = self.image_size[1] // self.patch_size[1]
|
202 |
+
if self.sincos_pos_emb:
|
203 |
+
pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens)
|
204 |
+
self.register_buffer("pos_emb", pos_emb)
|
205 |
+
else:
|
206 |
+
self.pos_emb = nn.Parameter(torch.zeros(1, (h_posemb * w_posemb), self.dim_tokens))
|
207 |
+
nn.init.normal_(self.pos_emb, std=init_std)
|
208 |
+
|
209 |
+
self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
|
210 |
+
nn.init.normal_(self.mod_emb, std=init_std)
|
211 |
+
|
212 |
+
# Token embedding (not needed if only masked tokens are given as input, but can be useful to train Token Critic)
|
213 |
+
self.token_emb = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.dim_tokens)
|
214 |
+
|
215 |
+
# Output projection layer
|
216 |
+
self.to_logits = nn.Linear(self.dim_tokens, self.vocab_size, bias=False)
|
217 |
+
|
218 |
+
if self.share_embedding:
|
219 |
+
# Share input and output embedding weights
|
220 |
+
self.to_logits.weight = self.token_emb.weight
|
221 |
+
|
222 |
+
@torch.jit.ignore
|
223 |
+
def no_weight_decay(self):
|
224 |
+
return set()
|
225 |
+
|
226 |
+
def forward_embed(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
227 |
+
"""
|
228 |
+
Forward pass through the embedding module, transforming tokenized spatial inputs to embeddings.
|
229 |
+
Creates corresponding modality and positional embeddings and adds them to the dict.
|
230 |
+
|
231 |
+
Args:
|
232 |
+
d (Dict[str, torch.Tensor]): Modality dict, with at least the following key:
|
233 |
+
- 'tensor' (torch.Tensor): Modality tokens for each batch (e.g. from tokenized images). Shape (B, H, W) where B is the batch size, H and W are height and width after tokenization.
|
234 |
+
|
235 |
+
|
236 |
+
Returns:
|
237 |
+
Dict[str, torch.Tensor]: Modality dict with added keys:
|
238 |
+
- 'x' (torch.Tensor): Embedded token sequence, which is replaced by mask tokens in the 4M decoder. Shape (B, H*W, D) where D is the embedding dimension.
|
239 |
+
- 'emb' (torch.Tensor): Sum of positional and modality embeddings for the token sequence. Shape (B, H*W, D).
|
240 |
+
- 'ids' (torch.Tensor): Reshaped token sequence from input dict, flattened in the spatial dimensions. Shape (B, H*W).
|
241 |
+
"""
|
242 |
+
ids = d['tensor']
|
243 |
+
B = ids.shape[0]
|
244 |
+
ids = ids.reshape(B, -1)
|
245 |
+
|
246 |
+
# Map to embedding
|
247 |
+
x = self.token_emb(ids)
|
248 |
+
|
249 |
+
# Create positional embedding + modality embedding
|
250 |
+
x_emb = repeat(self.pos_emb + self.mod_emb, '() n d -> b n d', b=B)
|
251 |
+
|
252 |
+
d['x'] = x
|
253 |
+
d['emb'] = x_emb
|
254 |
+
d['ids'] = ids
|
255 |
+
return d
|
256 |
+
|
257 |
+
def forward_logits(self, x: torch.Tensor) -> torch.Tensor:
|
258 |
+
"""
|
259 |
+
Forward pass through output projection layer, transforming sequence of embeddings to logits.
|
260 |
+
|
261 |
+
Args:
|
262 |
+
x (torch.Tensor): Output tokens from the decoder. Shape (B, M, D)
|
263 |
+
|
264 |
+
Returns:
|
265 |
+
torch.Tensor: Logits for each token in the sequence. Shape (B, M, V)
|
266 |
+
"""
|
267 |
+
logits = self.to_logits(x)
|
268 |
+
return logits
|
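A short sketch of how the decoder embedding module above might be exercised in isolation; the vocabulary size, token grid, and embedding dimension below are toy values chosen for illustration, not 4M defaults.

import torch
from fourm.models.decoder_embeddings import ImageTokenDecoderEmbedding

# 14x14 grid of image tokens (224 / 16 per side) with a toy vocabulary of 1024 codes.
emb = ImageTokenDecoderEmbedding(vocab_size=1024, patch_size=16, image_size=224, dim_tokens=768)

tokens = torch.randint(0, 1024, (2, 14, 14))         # (B, H, W) token ids from a tokenizer
d = emb.forward_embed({'tensor': tokens})
print(d['x'].shape, d['emb'].shape, d['ids'].shape)  # (2, 196, 768) (2, 196, 768) (2, 196)

# forward_logits projects decoder-dimensional embeddings back to vocabulary logits;
# it is applied to d['x'] here only to illustrate the shapes.
logits = emb.forward_logits(d['x'])                  # (2, 196, 1024)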
fourm/models/encoder_embeddings.py
ADDED
@@ -0,0 +1,422 @@
|
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Dict, List, Optional, Tuple, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
from einops import rearrange, repeat
|
19 |
+
|
20 |
+
from .fm_utils import build_1d_sincos_posemb, build_2d_sincos_posemb, pair
|
21 |
+
|
22 |
+
class SequenceEncoderEmbedding(nn.Module):
|
23 |
+
"""Embedding module for encoding sequence inputs, like captions or a sequence of objects.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
vocab_size: Vocabulary size
|
27 |
+
max_length: Maximum number of tokens in the sequence
|
28 |
+
dim_tokens: Dimension of output tokens. Can be set using init method.
|
29 |
+
sincos_pos_emb: Set to True (default) to use fixed 1D sin-cos positional embeddings
|
30 |
+
max_sincos_pos_emb: Maximum allowed length for sin-cos positional embeddings
|
31 |
+
padding_idx: Padding index for word embedding
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self,
|
35 |
+
vocab_size: int,
|
36 |
+
max_length: int,
|
37 |
+
dim_tokens: Optional[int] = None,
|
38 |
+
sincos_pos_emb: bool = True,
|
39 |
+
max_sincos_pos_emb: int = 512,
|
40 |
+
padding_idx: int = 0,
|
41 |
+
):
|
42 |
+
super().__init__()
|
43 |
+
self.vocab_size = vocab_size
|
44 |
+
self.max_length = max_length
|
45 |
+
self.dim_tokens = dim_tokens
|
46 |
+
self.sincos_pos_emb = sincos_pos_emb
|
47 |
+
self.padding_idx = padding_idx
|
48 |
+
self.max_sincos_pos_emb = max_sincos_pos_emb
|
49 |
+
|
50 |
+
if self.dim_tokens is not None:
|
51 |
+
self.init(dim_tokens=dim_tokens)
|
52 |
+
|
53 |
+
def init(self, dim_tokens: int = 768, init_std=0.02):
|
54 |
+
"""
|
55 |
+
Initialize parts of embedding module that are dependent on dimension of tokens.
|
56 |
+
Should be called when setting up FourM.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
dim_tokens: Dimension of tokens
|
60 |
+
init_std: Standard deviation of init
|
61 |
+
"""
|
62 |
+
self.dim_tokens = dim_tokens
|
63 |
+
|
64 |
+
# Task embedding identifying from which task a given token comes from
|
65 |
+
# Fixed-size positional embeddings. Can be interpolated to different input sizes
|
66 |
+
if self.sincos_pos_emb:
|
67 |
+
if self.max_length > self.max_sincos_pos_emb:
|
68 |
+
raise ValueError(f"Max length ({self.max_length}) is greater than the number of posembs ({self.max_sincos_pos_emb}")
|
69 |
+
pos_emb = build_1d_sincos_posemb(max_len=self.max_sincos_pos_emb, embed_dim=self.dim_tokens)[:self.max_length]
|
70 |
+
self.register_buffer("pos_emb", pos_emb) # self.pos_emb is now a buffer for FSDP
|
71 |
+
else:
|
72 |
+
self.pos_emb = nn.Parameter(torch.zeros(1, self.max_length, self.dim_tokens))
|
73 |
+
nn.init.normal_(self.pos_emb, std=init_std)
|
74 |
+
|
75 |
+
self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
|
76 |
+
nn.init.normal_(self.mod_emb, std=init_std)
|
77 |
+
|
78 |
+
# Token embedding
|
79 |
+
self.token_emb = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.dim_tokens,
|
80 |
+
padding_idx=self.padding_idx)
|
81 |
+
|
82 |
+
|
83 |
+
@torch.jit.ignore
|
84 |
+
def no_weight_decay(self):
|
85 |
+
return set()
|
86 |
+
|
87 |
+
def forward(self, d : Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
88 |
+
"""
|
89 |
+
Forward pass through embedding module, transforming sequence of ids to sequence of embeddings.
|
90 |
+
Creates corresponding modality and positional embeddings and adds them to the dict.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
d (Dict[str, torch.Tensor]): Modality dict with at least the following keys:
|
94 |
+
- 'tensor' (torch.Tensor): Input token sequence for each batch. Shape (B, L) where B is the batch size and L is the sequence length.
|
95 |
+
- 'input_mask' (torch.Tensor): Mask for valid tokens in the input sequence (set to 0 for valid tokens and 1 otherwise). Shape (B, L).
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
Dict[str, torch.Tensor]: Modality dict with added keys:
|
99 |
+
- 'x' (torch.Tensor): Embedded token sequence. Shape (B, L, D) where D is the embedding dimension.
|
100 |
+
- 'emb' (torch.Tensor): Sum of positional and modality embeddings for the input sequence. Shape (B, L, D).
|
101 |
+
"""
|
102 |
+
ids = d['tensor']
|
103 |
+
B = ids.shape[0]
|
104 |
+
assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first'
|
105 |
+
|
106 |
+
# Map to embedding
|
107 |
+
x = self.token_emb(ids)
|
108 |
+
|
109 |
+
expanded_pos_emb = repeat(self.pos_emb, "() n d -> b n d", b=B)
|
110 |
+
# Input pos encoding
|
111 |
+
input_mask = d['input_mask']
|
112 |
+
input_pos_id = (~input_mask).int().cumsum(dim=1) - 1
|
113 |
+
input_pos_id[input_mask] = 0
|
114 |
+
input_pos_emb = torch.gather(expanded_pos_emb, dim=1, index=repeat(input_pos_id, "b n -> b n d", d=expanded_pos_emb.shape[2]))
|
115 |
+
input_pos_emb[input_mask] = 0
|
116 |
+
|
117 |
+
x_emb = input_pos_emb + self.mod_emb
|
118 |
+
|
119 |
+
d['x'] = x
|
120 |
+
d['emb'] = x_emb
|
121 |
+
return d
|
122 |
+
|
123 |
+
class ImageTokenEncoderEmbedding(nn.Module):
|
124 |
+
"""Embedding module for tokenized spatial inputs.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
vocab_size: Vocabulary size
|
128 |
+
patch_size: Int or tuple of the patch size over the full image size.
|
129 |
+
dim_tokens: Dimension of output tokens. Can be set using init method.
|
130 |
+
sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings
|
131 |
+
image_size: Default image size. Used to initialize size of positional embeddings.
|
132 |
+
"""
|
133 |
+
def __init__(self,
|
134 |
+
vocab_size: int,
|
135 |
+
patch_size: Union[int, Tuple[int,int]] = 16,
|
136 |
+
dim_tokens: Optional[int] = None,
|
137 |
+
sincos_pos_emb: bool = True,
|
138 |
+
image_size: Union[int, Tuple[int]] = 224,
|
139 |
+
**kwargs):
|
140 |
+
|
141 |
+
super().__init__()
|
142 |
+
self.vocab_size = vocab_size
|
143 |
+
self.patch_size = pair(patch_size)
|
144 |
+
self.dim_tokens = dim_tokens
|
145 |
+
self.sincos_pos_emb = sincos_pos_emb
|
146 |
+
self.image_size = pair(image_size)
|
147 |
+
self.num_patches = (self.image_size[0] // self.patch_size[0]) * (self.image_size[1] // self.patch_size[1])
|
148 |
+
|
149 |
+
if self.dim_tokens is not None:
|
150 |
+
self.init(dim_tokens=dim_tokens)
|
151 |
+
|
152 |
+
def init(self, dim_tokens: int = 768, init_std=0.02):
|
153 |
+
"""
|
154 |
+
Initialize parts of module that are dependent on dimension of tokens.
|
155 |
+
Should be called when setting up FourM.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
dim_tokens: Dimension of tokens
|
159 |
+
init_std: Standard deviation of init
|
160 |
+
"""
|
161 |
+
self.dim_tokens = dim_tokens
|
162 |
+
|
163 |
+
# Task embedding identifying from which task a given token comes from
|
164 |
+
# Fixed-size positional embeddings. Can be interpolated to different input sizes
|
165 |
+
h_posemb = self.image_size[0] // self.patch_size[0]
|
166 |
+
w_posemb = self.image_size[1] // self.patch_size[1]
|
167 |
+
if self.sincos_pos_emb:
|
168 |
+
pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens)
|
169 |
+
self.register_buffer("pos_emb", pos_emb) # self.pos_emb is now a buffer for FSDP
|
170 |
+
else:
|
171 |
+
self.pos_emb = nn.Parameter(torch.zeros(1, (h_posemb * w_posemb), self.dim_tokens))
|
172 |
+
nn.init.normal_(self.pos_emb, std=init_std)
|
173 |
+
|
174 |
+
self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
|
175 |
+
nn.init.normal_(self.mod_emb, std=init_std)
|
176 |
+
|
177 |
+
# Token embedding
|
178 |
+
self.token_emb = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.dim_tokens)
|
179 |
+
|
180 |
+
@torch.jit.ignore
|
181 |
+
def no_weight_decay(self):
|
182 |
+
return set()
|
183 |
+
|
184 |
+
def forward(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
185 |
+
"""
|
186 |
+
Forward pass through embedding module, transforming image tokens to a sequence of embeddings.
|
187 |
+
Creates corresponding modality and positional embeddings and adds them to the dict.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
d (Dict[str, torch.Tensor]): Modality dict with at least the following key:
|
191 |
+
- 'tensor' (torch.Tensor): Input image tokens for each batch. Shape (B, H, W) where B is the batch size, and H, W are height and width of the tokenized image. - 'input_mask' (torch.Tensor): Mask for valid tokens in the input sequence (set to 0 for valid tokens and 1 otherwise). Shape (B, L).
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
Dict[str, torch.Tensor]: Modality dictionary with added keys:
|
195 |
+
- 'x' (torch.Tensor): Embedded token sequence. Shape (B, H*W, D).
|
196 |
+
- 'emb' (torch.Tensor): Sum of positional and modality embeddings for the input sequence. Shape (B, H*W, D).
|
197 |
+
"""
|
198 |
+
ids = d['tensor']
|
199 |
+
B = ids.shape[0]
|
200 |
+
ids = ids.reshape(B, -1)
|
201 |
+
|
202 |
+
# Map to embedding
|
203 |
+
x = self.token_emb(ids)
|
204 |
+
|
205 |
+
# Create positional embedding + modality embedding
|
206 |
+
x_emb = repeat(self.pos_emb + self.mod_emb, '() n d -> b n d', b=B)
|
207 |
+
|
208 |
+
d['x'] = x
|
209 |
+
d['emb'] = x_emb
|
210 |
+
|
211 |
+
return d
|
212 |
+
|
213 |
+
|
214 |
+
class ImageEncoderEmbedding(nn.Module):
|
215 |
+
"""Embedding module for spatial inputs, like images or feature maps.
|
216 |
+
Creates tokens from patches over the image.
|
217 |
+
|
218 |
+
This adapter / embedding differs from the one of MultiMAE by taking as input a dict and
|
219 |
+
separating positional embeddings and modality embeddings from the input projection.
|
220 |
+
Input projection is 'x', posemb + modemb is 'emb'
|
221 |
+
|
222 |
+
Args:
|
223 |
+
num_channels: Number of input channels of the image/feature map
|
224 |
+
patch_size: Int or tuple of the patch size over the full image size.
|
225 |
+
dim_tokens: Dimension of output tokens. Can be set using init method.
|
226 |
+
sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings
|
227 |
+
image_size: Default image size. Used to initialize size of positional embeddings.
|
228 |
+
"""
|
229 |
+
def __init__(self,
|
230 |
+
num_channels: int,
|
231 |
+
patch_size: Union[int, Tuple[int,int]],
|
232 |
+
dim_tokens: Optional[int] = None,
|
233 |
+
sincos_pos_emb: bool = True,
|
234 |
+
image_size: Union[int, Tuple[int]] = 224):
|
235 |
+
|
236 |
+
super().__init__()
|
237 |
+
self.num_channels = num_channels
|
238 |
+
self.patch_size = pair(patch_size)
|
239 |
+
self.dim_tokens = dim_tokens
|
240 |
+
self.sincos_pos_emb = sincos_pos_emb
|
241 |
+
self.image_size = pair(image_size)
|
242 |
+
self.num_patches = (self.image_size[0] // self.patch_size[0]) * (self.image_size[1] // self.patch_size[1])
|
243 |
+
|
244 |
+
if self.dim_tokens is not None:
|
245 |
+
self.init(dim_tokens=dim_tokens)
|
246 |
+
|
247 |
+
def init(self, dim_tokens: int = 768, init_std=0.02):
|
248 |
+
"""
|
249 |
+
Initialize parts of encoder that are dependent on dimension of tokens.
|
250 |
+
Should be called when setting up FourM.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
dim_tokens: Dimension of tokens
|
254 |
+
init_std: Standard deviation of init
|
255 |
+
"""
|
256 |
+
self.dim_tokens = dim_tokens
|
257 |
+
|
258 |
+
# Task embedding identifying from which task a given token comes from
|
259 |
+
# Fixed-size positional embeddings. Can be interpolated to different input sizes
|
260 |
+
h_posemb = self.image_size[0] // self.patch_size[0]
|
261 |
+
w_posemb = self.image_size[1] // self.patch_size[1]
|
262 |
+
if self.sincos_pos_emb:
|
263 |
+
pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens)
|
264 |
+
self.register_buffer("pos_emb", pos_emb) # self.pos_emb is now a buffer for FSDP
|
265 |
+
else:
|
266 |
+
self.pos_emb = nn.Parameter(torch.zeros(1, (h_posemb * w_posemb), self.dim_tokens))
|
267 |
+
nn.init.normal_(self.pos_emb, std=init_std)
|
268 |
+
|
269 |
+
self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
|
270 |
+
nn.init.normal_(self.mod_emb, std=init_std)
|
271 |
+
|
272 |
+
# Image -> tokens projection
|
273 |
+
# No bias term here, so modality embedding fully comes from self.mod_emb
|
274 |
+
self.proj = nn.Linear(self.num_channels * self.patch_size[0] * self.patch_size[1], self.dim_tokens, bias=False)
|
275 |
+
|
276 |
+
@torch.jit.ignore
|
277 |
+
def no_weight_decay(self):
|
278 |
+
return set()
|
279 |
+
|
280 |
+
def forward(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
281 |
+
"""
|
282 |
+
Forward pass through embedding module, transforming image to sequence of tokens.
|
283 |
+
Creates corresponding modality and positional embeddings and adds them to the dict.
|
284 |
+
|
285 |
+
Args:
|
286 |
+
d (Dict[str, torch.Tensor]): Modality dict with at least the following key:
|
287 |
+
- 'tensor' (torch.Tensor): Input image for each batch. Shape (B, C, H, W) where B is the batch size, C is the number of channels, and H, W are height and width of the image.
|
288 |
+
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
Dict[str, torch.Tensor]: Modality dict with added keys:
|
292 |
+
- 'x' (torch.Tensor): Embedded token sequence. Shape (B, (H / PH) * (W / PW), D), where PH and PW are the patch sizes
|
293 |
+
- 'emb' (torch.Tensor): Sum of positional and modality embeddings for the input sequence. Shape (B, (H / PH) * (W / PW), D)
|
294 |
+
"""
|
295 |
+
x = d['tensor']
|
296 |
+
B, C, H, W = x.shape
|
297 |
+
assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first'
|
298 |
+
assert (H % self.patch_size[0] == 0) and (W % self.patch_size[1] == 0), f'Image sizes {H}x{W} must be divisible by patch sizes {self.patch_size[0]}x{self.patch_size[1]}'
|
299 |
+
|
300 |
+
# Create patches [B, C, H, W] -> [B, (H*W), C]
|
301 |
+
x_patch = self.proj(rearrange(x, 'b d (nh ph) (nw pw) -> b (nh nw) (ph pw d)', ph=self.patch_size[0], pw=self.patch_size[1]))
|
302 |
+
|
303 |
+
# Create positional embedding + modality embedding
|
304 |
+
x_emb = repeat(self.pos_emb + self.mod_emb, '() n d -> b n d', b=B)
|
305 |
+
|
306 |
+
d['x'] = x_patch
|
307 |
+
d['emb'] = x_emb
|
308 |
+
|
309 |
+
return d
|
310 |
+
|
311 |
+
|
312 |
+
class SequenceEmbEncoderEmbedding(nn.Module):
|
313 |
+
"""Adapter for sequence emb inputs, like T5-XXL, CLIP text embeddings.
|
314 |
+
|
315 |
+
Args:
|
316 |
+
max_length: Maximum number of tokens in the sequence
|
317 |
+
dim_tokens: Dimension of output tokens. Can be set using init method.
|
318 |
+
sincos_pos_emb: Set to True (default) to use fixed 1D sin-cos positional embeddings
|
319 |
+
padding_idx: Padding index for word embedding
|
320 |
+
orig_emb_dim: Dimension of original embeddings
|
321 |
+
bottleneck_dim: Dimension of bottleneck layer
|
322 |
+
use_bottleneck: Set to True to use bottleneck layer
|
323 |
+
"""
|
324 |
+
def __init__(self,
|
325 |
+
max_length: int,
|
326 |
+
dim_tokens: Optional[int] = None,
|
327 |
+
sincos_pos_emb: bool = True,
|
328 |
+
max_sincos_pos_emb: int = 512,
|
329 |
+
padding_idx: int = 0,
|
330 |
+
orig_emb_dim: int = 4096,
|
331 |
+
bottleneck_dim: int = 64,
|
332 |
+
use_bottleneck: bool = False,
|
333 |
+
):
|
334 |
+
super().__init__()
|
335 |
+
self.max_length = max_length
|
336 |
+
self.dim_tokens = dim_tokens
|
337 |
+
self.sincos_pos_emb = sincos_pos_emb
|
338 |
+
self.padding_idx = padding_idx
|
339 |
+
self.max_sincos_pos_emb = max_sincos_pos_emb
|
340 |
+
self.orig_emb_dim = orig_emb_dim
|
341 |
+
self.use_bottleneck = use_bottleneck
|
342 |
+
if self.use_bottleneck:
|
343 |
+
self.bottleneck_dim = bottleneck_dim
|
344 |
+
|
345 |
+
if self.dim_tokens is not None:
|
346 |
+
self.init(dim_tokens=dim_tokens)
|
347 |
+
|
348 |
+
def init(self, dim_tokens: int = 768, init_std=0.02):
|
349 |
+
"""
|
350 |
+
Initialize parts of embedding module that are dependent on dimension of tokens.
|
351 |
+
Should be called when setting up FourM.
|
352 |
+
|
353 |
+
Args:
|
354 |
+
dim_tokens: Dimension of tokens
|
355 |
+
init_std: Standard deviation of init
|
356 |
+
"""
|
357 |
+
self.dim_tokens = dim_tokens
|
358 |
+
|
359 |
+
# Task embedding identifying from which task a given token comes from
|
360 |
+
# Fixed-size positional embeddings. Can be interpolated to different input sizes
|
361 |
+
if self.sincos_pos_emb:
|
362 |
+
if self.max_length > self.max_sincos_pos_emb:
|
363 |
+
raise ValueError(f"Max length ({self.max_length}) is greater than the number of posembs ({self.max_sincos_pos_emb}")
|
364 |
+
pos_emb = build_1d_sincos_posemb(max_len=self.max_sincos_pos_emb, embed_dim=self.dim_tokens)[:self.max_length]
|
365 |
+
self.register_buffer("pos_emb", pos_emb) # self.pos_emb is now a buffer for FSDP
|
366 |
+
else:
|
367 |
+
self.pos_emb = nn.Parameter(torch.zeros(1, self.max_length, self.dim_tokens))
|
368 |
+
nn.init.normal_(self.pos_emb, std=init_std)
|
369 |
+
|
370 |
+
self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
|
371 |
+
nn.init.normal_(self.mod_emb, std=init_std)
|
372 |
+
|
373 |
+
# Token embedding projection
|
374 |
+
if self.use_bottleneck:
|
375 |
+
self.emb_proj = nn.Sequential(
|
376 |
+
nn.Linear(self.orig_emb_dim, self.bottleneck_dim),
|
377 |
+
nn.Linear(self.bottleneck_dim, self.dim_tokens),
|
378 |
+
)
|
379 |
+
else:
|
380 |
+
self.emb_proj = nn.Linear(self.orig_emb_dim, self.dim_tokens)
|
381 |
+
|
382 |
+
|
383 |
+
@torch.jit.ignore
|
384 |
+
def no_weight_decay(self):
|
385 |
+
return set()
|
386 |
+
|
387 |
+
def forward(self, d):
|
388 |
+
"""
|
389 |
+
Forward pass through embedding module, projecting original embeddings to the Transformer dimension.
|
390 |
+
Creates corresponding modality and positional embeddings and adds them to the dict.
|
391 |
+
|
392 |
+
Args:
|
393 |
+
d (Dict[str, torch.Tensor]): Modality dict with at least the following keys:
|
394 |
+
- 'tensor' (torch.Tensor): Input token sequence for each batch. Shape (B, L, E) where B is the batch size and L is the sequence length, and E is the dimension of the original embeddings.
|
395 |
+
- 'input_mask' (torch.Tensor): Mask for valid tokens in the input sequence (set to 0 for valid tokens and 1 otherwise). Shape (B, L).
|
396 |
+
|
397 |
+
Returns:
|
398 |
+
Dict[str, torch.Tensor]: Modality dict with added keys:
|
399 |
+
- 'x' (torch.Tensor): Embedded token sequence. Shape (B, L, D) where D is the Transformer embedding dimension.
|
400 |
+
- 'emb' (torch.Tensor): Sum of positional and modality embeddings for the input sequence. Shape (B, L, D).
|
401 |
+
"""
|
402 |
+
orig_emb = d['tensor']
|
403 |
+
B = orig_emb.shape[0]
|
404 |
+
assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first'
|
405 |
+
|
406 |
+
# Map to embedding
|
407 |
+
x = self.emb_proj(orig_emb)
|
408 |
+
|
409 |
+
expanded_pos_emb = repeat(self.pos_emb, "() n d -> b n d", b=B)
|
410 |
+
# Input pos encoding
|
411 |
+
input_mask = d['input_mask']
|
412 |
+
input_pos_id = (~input_mask).int().cumsum(dim=1) - 1
|
413 |
+
input_pos_id[input_mask] = 0
|
414 |
+
input_pos_emb = torch.gather(expanded_pos_emb, dim=1, index=repeat(input_pos_id, "b n -> b n d", d=expanded_pos_emb.shape[2]))
|
415 |
+
input_pos_emb[input_mask] = 0
|
416 |
+
|
417 |
+
x_emb = input_pos_emb + self.mod_emb
|
418 |
+
|
419 |
+
d['x'] = x
|
420 |
+
d['emb'] = x_emb
|
421 |
+
return d
|
422 |
+
|
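For completeness, a similar sketch for the patch-based encoder embedding above; the sizes are illustrative only (16x16 patches over a 224x224 RGB image yield 196 tokens).

import torch
from fourm.models.encoder_embeddings import ImageEncoderEmbedding

emb = ImageEncoderEmbedding(num_channels=3, patch_size=16, image_size=224, dim_tokens=768)

# 'x' holds the linear patch projection, 'emb' the summed positional + modality embeddings.
d = emb({'tensor': torch.randn(2, 3, 224, 224)})
print(d['x'].shape, d['emb'].shape)   # torch.Size([2, 196, 768]) torch.Size([2, 196, 768])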
fourm/models/fm.py
ADDED
@@ -0,0 +1,1130 @@
|
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
import random
|
16 |
+
import copy
|
17 |
+
from functools import partial
|
18 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from einops import rearrange, repeat
|
22 |
+
from torch import nn
|
23 |
+
import torch.nn.functional as F
|
24 |
+
|
25 |
+
from fourm.utils.timm.registry import register_model
|
26 |
+
from huggingface_hub import PyTorchModelHubMixin
|
27 |
+
|
28 |
+
from .fm_utils import Block, DecoderBlock, LayerNorm
|
29 |
+
from fourm.data.modality_info import MODALITY_INFO
|
30 |
+
|
31 |
+
|
32 |
+
# Model definitions
|
33 |
+
__all__ = [
|
34 |
+
# GELU models
|
35 |
+
'fm_tiny_6e_6d_gelu',
|
36 |
+
'fm_small_8e_8d_gelu',
|
37 |
+
'fm_base_12e_12d_gelu',
|
38 |
+
'fm_large_24e_24d_gelu',
|
39 |
+
'fm_xlarge_24e_24d_gelu',
|
40 |
+
# SwiGLU models
|
41 |
+
'fm_tiny_6e_6d_swiglu_nobias',
|
42 |
+
'fm_small_8e_8d_swiglu_nobias',
|
43 |
+
'fm_base_12e_12d_swiglu_nobias',
|
44 |
+
'fm_large_24e_24d_swiglu_nobias',
|
45 |
+
'fm_xlarge_24e_24d_swiglu_nobias',
|
46 |
+
# SwiGLU + QKNorm models
|
47 |
+
'fm_base_12e_12d_swiglu_qknorm_nobias',
|
48 |
+
'fm_large_24e_24d_swiglu_qknorm_nobias',
|
49 |
+
'fm_xlarge_24e_24d_swiglu_qknorm_nobias',
|
50 |
+
]
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
class FourM(nn.Module):
|
55 |
+
"""4M model.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
encoder_embeddings: Dict of encoder embedding modules.
|
59 |
+
decoder_embeddings: Dict of decoder embedding modules.
|
60 |
+
modality_info: Dict containing modality information.
|
61 |
+
dim: Embedding dimension.
|
62 |
+
encoder_depth: Number of encoder blocks.
|
63 |
+
decoder_depth: Number of decoder blocks.
|
64 |
+
num_heads: Number of attention heads.
|
65 |
+
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
66 |
+
qkv_bias: If True, add a learnable bias to query, key, value projections.
|
67 |
+
proj_bias: If True, add a learnable bias to the last projection of the attention block.
|
68 |
+
mlp_bias: If True, add a learnable bias to linear layers in the MLP / feed-forward.
|
69 |
+
drop_path_rate_encoder: Stochastic depth rate for encoder.
|
70 |
+
drop_path_rate_decoder: Stochastic depth rate for decoder.
|
71 |
+
shared_drop_path: If True, shares drop path between encoder and decoder.
|
72 |
+
act_layer: Activation layer to be used.
|
73 |
+
norm_layer: Normalization layer to be used.
|
74 |
+
gated_mlp: If True, make the feedforward gated (e.g., SwiGLU).
|
75 |
+
qk_norm: If True, applies normalization to queries and keys (QKNorm).
|
76 |
+
decoder_causal_mask: If True, decoder will use a causal mask for all tokens.
|
77 |
+
decoder_sep_mask: If True, decoder attention is restricted to within each modality only.
|
78 |
+
num_register_tokens: Number of register tokens.
|
79 |
+
use_act_checkpoint: If True, use activation checkpoint for each block.
|
80 |
+
"""
|
81 |
+
def __init__(self,
|
82 |
+
encoder_embeddings: Dict[str, nn.Module],
|
83 |
+
decoder_embeddings: Dict[str, nn.Module],
|
84 |
+
modality_info: Dict[str, Any],
|
85 |
+
dim: int = 768,
|
86 |
+
encoder_depth: int = 12,
|
87 |
+
decoder_depth: int = 12,
|
88 |
+
num_heads: int = 12,
|
89 |
+
mlp_ratio: float = 4.0,
|
90 |
+
qkv_bias: bool = True,
|
91 |
+
proj_bias: bool = True,
|
92 |
+
mlp_bias: bool = True,
|
93 |
+
drop_path_rate_encoder: float = 0.0,
|
94 |
+
drop_path_rate_decoder: float = 0.0,
|
95 |
+
shared_drop_path: bool = False,
|
96 |
+
act_layer: nn.Module = nn.GELU,
|
97 |
+
norm_layer: Union[partial, nn.Module] = partial(LayerNorm, eps=1e-6),
|
98 |
+
gated_mlp: bool = False, # Make the feedforward gated for e.g. SwiGLU
|
99 |
+
qk_norm: bool = False,
|
100 |
+
decoder_causal_mask: bool = False,
|
101 |
+
decoder_sep_mask: bool = True,
|
102 |
+
num_register_tokens: int = 0,
|
103 |
+
use_act_checkpoint: bool = False,
|
104 |
+
share_modality_embeddings: bool = True,
|
105 |
+
):
|
106 |
+
super().__init__()
|
107 |
+
|
108 |
+
self.modality_info = modality_info
|
109 |
+
self.dim = dim
|
110 |
+
self.decoder_causal_mask = decoder_causal_mask
|
111 |
+
self.decoder_sep_mask = decoder_sep_mask
|
112 |
+
self.init_std = 0.02
|
113 |
+
self.use_act_checkpoint = use_act_checkpoint
|
114 |
+
self.num_register_tokens = num_register_tokens
|
115 |
+
|
116 |
+
|
117 |
+
# Encoder embeddings & init
|
118 |
+
self.encoder_modalities = set(encoder_embeddings.keys())
|
119 |
+
for emb in encoder_embeddings.values():
|
120 |
+
emb.init(dim_tokens=dim, init_std=self.init_std)
|
121 |
+
self.encoder_embeddings = nn.ModuleDict(encoder_embeddings)
|
122 |
+
|
123 |
+
# Decoder embeddings & init
|
124 |
+
self.decoder_modalities = set(decoder_embeddings.keys())
|
125 |
+
for emb in decoder_embeddings.values():
|
126 |
+
emb.init(dim_tokens=dim, init_std=self.init_std)
|
127 |
+
self.decoder_embeddings = nn.ModuleDict(decoder_embeddings)
|
128 |
+
|
129 |
+
# Share modality embeddings across the encoder and decoder embedding modules
|
130 |
+
if share_modality_embeddings:
|
131 |
+
self.share_modality_embeddings()
|
132 |
+
|
133 |
+
## Transformer encoder
|
134 |
+
if shared_drop_path:
|
135 |
+
dpr_encoder = [x.item() for x in torch.linspace(0, drop_path_rate_encoder, encoder_depth + decoder_depth)][:encoder_depth]
|
136 |
+
else:
|
137 |
+
dpr_encoder = [x.item() for x in torch.linspace(0, drop_path_rate_encoder, encoder_depth)] # stochastic depth decay rule
|
138 |
+
|
139 |
+
self.encoder = nn.ModuleList([
|
140 |
+
Block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias,
|
141 |
+
drop_path=dpr_encoder[i], act_layer=act_layer, norm_layer=norm_layer, gated_mlp=gated_mlp, qk_norm=qk_norm)
|
142 |
+
for i in range(encoder_depth)
|
143 |
+
])
|
144 |
+
self.encoder_norm = norm_layer(dim)
|
145 |
+
|
146 |
+
|
147 |
+
## Transformer decoder
|
148 |
+
if shared_drop_path:
|
149 |
+
dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate_decoder, encoder_depth + decoder_depth)][encoder_depth:]
|
150 |
+
else:
|
151 |
+
dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate_decoder, decoder_depth)] # stochastic depth decay rule
|
152 |
+
|
153 |
+
# Projection of encoder tokens before adding the embeddings again
|
154 |
+
self.decoder_proj_context = nn.Linear(dim, dim)
|
155 |
+
|
156 |
+
self.decoder = nn.ModuleList([
|
157 |
+
DecoderBlock(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias,
|
158 |
+
drop_path=dpr_decoder[i], act_layer=act_layer, norm_layer=norm_layer, gated_mlp=gated_mlp, qk_norm=qk_norm)
|
159 |
+
for i in range(decoder_depth)
|
160 |
+
])
|
161 |
+
self.decoder_norm = norm_layer(dim)
|
162 |
+
|
163 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, dim))
|
164 |
+
nn.init.normal_(self.mask_token, std=self.init_std)
|
165 |
+
|
166 |
+
# Additional register tokens that can be used by the encoder during fine-tuning
|
167 |
+
if self.num_register_tokens > 0:
|
168 |
+
self.register_tokens = nn.Parameter(torch.zeros(1, self.num_register_tokens, dim))
|
169 |
+
nn.init.normal_(self.register_tokens, std=self.init_std)
|
170 |
+
else:
|
171 |
+
self.register_tokens = None
|
172 |
+
|
173 |
+
# Weight init
|
174 |
+
self.init_weights()
|
175 |
+
|
176 |
+
def share_modality_embeddings(self):
|
177 |
+
"""Share modality embeddings across the encoder and decoder embedding modules."""
|
178 |
+
shared_modalities = self.encoder_modalities & self.decoder_modalities
|
179 |
+
for mod in shared_modalities:
|
180 |
+
self.decoder_embeddings[mod].mod_emb = self.encoder_embeddings[mod].mod_emb
|
181 |
+
|
182 |
+
def init_weights(self):
|
183 |
+
"""Weight initialization following MAE's initialization scheme"""
|
184 |
+
|
185 |
+
for name, m in self.named_modules():
|
186 |
+
# Skipping tokenizers to avoid reinitializing them
|
187 |
+
if "tokenizer" in name:
|
188 |
+
continue
|
189 |
+
# Linear
|
190 |
+
elif isinstance(m, nn.Linear):
|
191 |
+
if 'qkv' in name:
|
192 |
+
# treat the weights of Q, K, V separately
|
193 |
+
val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
|
194 |
+
nn.init.uniform_(m.weight, -val, val)
|
195 |
+
elif 'kv' in name:
|
196 |
+
# treat the weights of K, V separately
|
197 |
+
val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
|
198 |
+
nn.init.uniform_(m.weight, -val, val)
|
199 |
+
else:
|
200 |
+
nn.init.xavier_uniform_(m.weight)
|
201 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
202 |
+
nn.init.constant_(m.bias, 0)
|
203 |
+
# LayerNorm
|
204 |
+
elif isinstance(m, nn.LayerNorm) or isinstance(m, LayerNorm):
|
205 |
+
nn.init.constant_(m.weight, 1.0)
|
206 |
+
if m.bias is not None:
|
207 |
+
nn.init.constant_(m.bias, 0)
|
208 |
+
# Embedding
|
209 |
+
elif isinstance(m, nn.Embedding):
|
210 |
+
nn.init.normal_(m.weight, std=self.init_std)
|
211 |
+
# Conv2d
|
212 |
+
elif isinstance(m, nn.Conv2d):
|
213 |
+
if '.proj' in name:
|
214 |
+
# From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
|
215 |
+
w = m.weight.data
|
216 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
217 |
+
|
218 |
+
def get_num_layers_encoder(self):
|
219 |
+
return len(self.encoder)
|
220 |
+
|
221 |
+
def get_num_layers_decoder(self):
|
222 |
+
return len(self.decoder)
|
223 |
+
|
224 |
+
def get_num_layers(self):
|
225 |
+
return self.get_num_layers_encoder() + self.get_num_layers_decoder()
|
226 |
+
|
227 |
+
@torch.jit.ignore
|
228 |
+
def no_weight_decay(self):
|
229 |
+
no_wd_set = set()
|
230 |
+
|
231 |
+
for mod, emb_module in self.encoder_embeddings.items():
|
232 |
+
if hasattr(emb_module, 'no_weight_decay'):
|
233 |
+
to_skip = emb_module.no_weight_decay()
|
234 |
+
to_skip = set([f'encoder_embeddings.{mod}.{name}' for name in to_skip])
|
235 |
+
no_wd_set = no_wd_set | to_skip
|
236 |
+
|
237 |
+
for mod, emb_module in self.decoder_embeddings.items():
|
238 |
+
if hasattr(emb_module, 'no_weight_decay'):
|
239 |
+
to_skip = emb_module.no_weight_decay()
|
240 |
+
to_skip = set([f'decoder_embeddings.{mod}.{name}' for name in to_skip])
|
241 |
+
no_wd_set = no_wd_set | to_skip
|
242 |
+
|
243 |
+
return no_wd_set
|
244 |
+
|
245 |
+
def cat_encoder_tensors(self, mod_dict: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor]:
|
246 |
+
"""Concatenate encoder tensors from different modalities.
|
247 |
+
|
248 |
+
Args:
|
249 |
+
mod_dict (dict): A dictionary containing information for each modality.
|
250 |
+
Expected keys for each modality are 'x' (input tokens),
|
251 |
+
'emb' (embeddings), 'input_mask', etc.
|
252 |
+
|
253 |
+
Returns:
|
254 |
+
tuple:
|
255 |
+
- encoder_tokens_all (torch.Tensor): Concatenated encoder tokens from all modalities. Shape (B, O, D) where O is the total number of all encoder tokens.
|
256 |
+
- emb_all (torch.Tensor): Concatenated encoder embeddings from all modalities. Shape (B, O, D)
|
257 |
+
- encoder_mask_all (torch.Tensor): Concatenated boolean masks indicating which tokens are part of the encoder input (set to 0 for valid tokens, 1 otherwise). Shape (B, O)
|
258 |
+
- mod_mask_all (torch.Tensor): Concatenated integer mask marking the modality type for each encoder token. Shape (B, O)
|
259 |
+
"""
|
260 |
+
|
261 |
+
encoder_tokens_all = []
|
262 |
+
emb_all = []
|
263 |
+
encoder_mask_all = []
|
264 |
+
mod_mask_all = []
|
265 |
+
|
266 |
+
for mod, d in mod_dict.items():
|
267 |
+
encoder_tokens_all.append(d['x'])
|
268 |
+
emb_all.append(d['emb'])
|
269 |
+
encoder_mask_all.append(d['input_mask'])
|
270 |
+
mod_mask_all.append(torch.full_like(d['input_mask'], self.modality_info[mod]['id'], dtype=torch.int16))
|
271 |
+
|
272 |
+
encoder_tokens_all = torch.cat(encoder_tokens_all, dim=1)
|
273 |
+
emb_all = torch.cat(emb_all, dim=1)
|
274 |
+
encoder_mask_all = torch.cat(encoder_mask_all, dim=1)
|
275 |
+
mod_mask_all = torch.cat(mod_mask_all, dim=1)
|
276 |
+
|
277 |
+
return encoder_tokens_all, emb_all, encoder_mask_all, mod_mask_all
|
278 |
+
|
279 |
+
def cat_decoder_tensors(self, mod_dict: Dict[str, Dict[str, torch.Tensor]]) -> Tuple[torch.Tensor]:
|
280 |
+
"""Concatenate decoder tensors from different modalities.
|
281 |
+
|
282 |
+
Args:
|
283 |
+
mod_dict (dict): A dictionary containing information for each modality.
|
284 |
+
Expected keys for each modality include 'x' (input tokens),
|
285 |
+
'ids' (target IDs), 'emb' (embeddings), 'target_mask', 'decoder_attention_mask', etc.
|
286 |
+
|
287 |
+
|
288 |
+
Returns:
|
289 |
+
tuple:
|
290 |
+
- decoder_tokens_all (torch.Tensor): Concatenated decoder tokens from all modalities. Shape (B, P, D) where P is the total number of all decoder tokens.
|
291 |
+
- emb_all (torch.Tensor): Concatenated decoder embeddings from all modalities. Shape (B, P, D)
|
292 |
+
- decoder_mask_all (torch.Tensor): Concatenated boolean masks indicating which tokens are part of the decoder input / target (set to 0 for valid tokens, 1 otherwise). Shape (B, P)
|
293 |
+
- target_ids_all (torch.Tensor): Concatenated target IDs from all modalities. Shape (B, P)
|
294 |
+
- attention_mask_all (torch.Tensor): Concatenated attention masks in compressed format, needs to be passed to adapt_decoder_attention_mask() to obtain the final attention mask. Shape (B, P)
|
295 |
+
- mod_mask_all (torch.Tensor): Concatenated integer mask marking the modality type for each decoder token. Shape (B, P)
|
296 |
+
"""
|
297 |
+
|
298 |
+
decoder_tokens_all = []
|
299 |
+
target_ids_all = []
|
300 |
+
emb_all = []
|
301 |
+
decoder_mask_all = []
|
302 |
+
attention_mask_all = []
|
303 |
+
mod_mask_all = []
|
304 |
+
|
305 |
+
# Shuffle order in which modalities are provided (useful for modality causal mask)
|
306 |
+
mod_dict = {mod: d for mod, d in random.sample(mod_dict.items(), len(mod_dict))}
|
307 |
+
|
308 |
+
for mod, d in mod_dict.items():
|
309 |
+
if self.modality_info[mod]['type'] in ['seq', 'seq_emb', 'seq_token']:
|
310 |
+
# Important: This makes the assumption that the target sequence appears sequentially
|
311 |
+
# before sorting / gathering
|
312 |
+
decoder_tokens_all.append(d['x'][:, :-1])
|
313 |
+
target_ids_all.append(d['ids'][:, 1:]) # Shifted left
|
314 |
+
emb_all.append(d['emb'][:, :-1])
|
315 |
+
# Logical or with left shifting removes the last unmasked position
|
316 |
+
decoder_mask_all.append(torch.logical_or(d['target_mask'][:, 1:], d['target_mask'][:, :-1]))
|
317 |
+
# Add attention mask ids
|
318 |
+
attention_mask_all.append(d['decoder_attention_mask'][:, :-1])
|
319 |
+
mod_mask_all.append(torch.full_like(d['ids'][:, :-1], self.modality_info[mod]['id'], dtype=torch.int16))
|
320 |
+
else:
|
321 |
+
# Important: For 2d / image modalities, the decoder input tokens are replaced by the mask token
|
322 |
+
decoder_tokens_all.append(torch.zeros_like(d['x']) + self.mask_token) # Replace x by mask token
|
323 |
+
target_ids_all.append(d['ids'])
|
324 |
+
emb_all.append(d['emb'])
|
325 |
+
decoder_mask_all.append(d['target_mask'])
|
326 |
+
attention_mask_all.append(d['decoder_attention_mask'])
|
327 |
+
mod_mask_all.append(torch.full_like(d['ids'], self.modality_info[mod]['id'], dtype=torch.int16))
|
328 |
+
|
329 |
+
decoder_tokens_all = torch.cat(decoder_tokens_all, dim=1)
|
330 |
+
emb_all = torch.cat(emb_all, dim=1)
|
331 |
+
decoder_mask_all = torch.cat(decoder_mask_all, dim=1)
|
332 |
+
target_ids_all = torch.cat(target_ids_all, dim=1)
|
333 |
+
attention_mask_all = torch.cat(attention_mask_all, dim=1)
|
334 |
+
mod_mask_all = torch.cat(mod_mask_all, dim=1)
|
335 |
+
|
336 |
+
return decoder_tokens_all, emb_all, decoder_mask_all, target_ids_all, attention_mask_all, mod_mask_all
|
337 |
+
|
338 |
+
def forward_mask_encoder(self, mod_dict: Dict[str, Dict[str, torch.Tensor]], num_encoder_tokens: int) -> Tuple[torch.Tensor]:
|
339 |
+
"""Concatenates and mask encoder tensors based on provided modality information.
|
340 |
+
|
341 |
+
This function consolidates encoder tokens from multiple modalities, then selects a specified number of them based on modality information (i.e. masking).
|
342 |
+
|
343 |
+
Args:
|
344 |
+
mod_dict (dict): Dictionary containing tensors for different modalities.
|
345 |
+
It is expected to have keys for each modality and values
|
346 |
+
containing the modalities' associated tensors.
|
347 |
+
num_encoder_tokens (int): Number of encoder tokens to retain after masking.
|
348 |
+
|
349 |
+
Returns:
|
350 |
+
tuple:
|
351 |
+
- encoder_tokens (torch.Tensor): Selected encoder tokens from all modalities. Shape (B, N, D) where N is the number of selected encoder tokens.
|
352 |
+
- encoder_emb (torch.Tensor): Corresponding embeddings for encoder tokens. Shape (B, N, D)
|
353 |
+
- encoder_mask (torch.Tensor): A boolean mask indicating which encoder tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N)
|
354 |
+
- mod_mask (torch.Tensor): An integer mask marking the modality type for each encoder token (with -1 indicating unassigned pad tokens). Shape (B, N)
|
355 |
+
|
356 |
+
Notes:
|
357 |
+
- If `num_register_tokens` is set and greater than 0, register tokens are added at the beginning of the sequence.
|
358 |
+
"""
|
359 |
+
B = list(mod_dict.values())[0]['tensor'].shape[0]
|
360 |
+
|
361 |
+
encoder_tokens_all, emb_all, encoder_mask_all, mod_mask_all = self.cat_encoder_tensors(mod_dict)
|
362 |
+
|
363 |
+
# Add arange multiplied by small constant to mask so they get sorted in a deterministic way
|
364 |
+
mask_arange = torch.arange(encoder_mask_all.shape[1], device=encoder_mask_all.device).unsqueeze(0) * 1e-6
|
365 |
+
ids_shuffle = torch.argsort(encoder_mask_all + mask_arange, dim=1)
|
366 |
+
# ids_restore = torch.argsort(ids_shuffle, dim=1)
|
367 |
+
ids_keep = ids_shuffle[:, :num_encoder_tokens]
|
368 |
+
|
369 |
+
encoder_tokens = torch.gather(encoder_tokens_all, dim=1,
|
370 |
+
index=repeat(ids_keep, "b n -> b n d", d=encoder_tokens_all.shape[2]))
|
371 |
+
encoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
|
372 |
+
encoder_mask = torch.gather(encoder_mask_all, dim=1, index=ids_keep)
|
373 |
+
mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)
|
374 |
+
|
375 |
+
if self.num_register_tokens > 0:
|
376 |
+
register_tokens = repeat(self.register_tokens, '() n d -> b n d', b=B)
|
377 |
+
# We add register tokens at the beginning of the sequence
|
378 |
+
encoder_tokens = torch.cat([register_tokens, encoder_tokens], dim=1)
|
379 |
+
encoder_emb = torch.cat([torch.zeros_like(register_tokens), encoder_emb], dim=1)
|
380 |
+
encoder_mask = torch.cat([torch.zeros((B, register_tokens.shape[1]), dtype=torch.bool, device=encoder_mask.device), encoder_mask], dim=1)
|
381 |
+
mod_mask = torch.cat([torch.full((B, register_tokens.shape[1]), -1, dtype=torch.int16, device=mod_mask.device), mod_mask], dim=1)
|
382 |
+
|
383 |
+
encoder_tokens[encoder_mask] = 0.
|
384 |
+
encoder_emb[encoder_mask] = 0.
|
385 |
+
mod_mask[encoder_mask] = -1
|
386 |
+
# Mask could be of shape 'b n1 n2' but not needed for masked_fill
|
387 |
+
# This means this mask can then be re-used for decoder cross-attention
|
388 |
+
encoder_mask = rearrange(encoder_mask, 'b n2 -> b 1 n2')
|
389 |
+
|
390 |
+
return encoder_tokens, encoder_emb, encoder_mask, mod_mask
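To make the selection above concrete: adding a tiny arange-scaled offset before argsort breaks ties deterministically, so unmasked tokens (mask value 0) move to the front in their original order and the first num_encoder_tokens indices are gathered. A minimal standalone sketch of just that step (tensor names and sizes are illustrative, not taken from the model):

import torch

B, N, D, num_keep = 2, 6, 4, 3
tokens = torch.randn(B, N, D)
mask = torch.tensor([[0, 1, 0, 1, 0, 1],
                     [1, 1, 0, 0, 0, 1]], dtype=torch.float32)  # 0 = keep, 1 = masked out

# Tie-break with a small arange so equal mask values keep their original order
mask_arange = torch.arange(N).unsqueeze(0) * 1e-6
ids_shuffle = torch.argsort(mask + mask_arange, dim=1)  # unmasked tokens sorted to the front
ids_keep = ids_shuffle[:, :num_keep]                    # indices of the surviving tokens

kept = torch.gather(tokens, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))
print(ids_keep)    # tensor([[0, 2, 4], [2, 3, 4]])
print(kept.shape)  # torch.Size([2, 3, 4])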
|
391 |
+
|
392 |
+
def forward_mask_decoder(self, mod_dict: Dict[str, Dict[str, torch.Tensor]], num_decoder_tokens: int) -> Tuple[torch.Tensor]:
|
393 |
+
"""Concatenates and mask decoder tensors based on provided modality information.
|
394 |
+
|
395 |
+
This function consolidates decoder tokens from multiple modalities, selects a specified number of them based on modality information, and applies appropriate masking.
|
396 |
+
|
397 |
+
Args:
|
398 |
+
mod_dict (dict): Dictionary containing tensors for different modalities.
|
399 |
+
It is expected to have keys for each modality and values
|
400 |
+
containing the modalities' associated tensors.
|
401 |
+
num_decoder_tokens (int): Number of decoder tokens to retain after masking.
|
402 |
+
|
403 |
+
Returns:
|
404 |
+
tuple:
|
405 |
+
- decoder_tokens (torch.Tensor): Selected decoder tokens from all modalities. Shape (B, M, D) where M is the number of selected decoder tokens.
|
406 |
+
- decoder_emb (torch.Tensor): Corresponding embeddings for decoder tokens. Shape (B, M, D)
|
407 |
+
- decoder_mask (torch.Tensor): A boolean mask indicating which decoder tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, M)
|
408 |
+
- target_ids (torch.Tensor): IDs of the target tokens corresponding to the decoder tokens. Shape (B, M)
|
409 |
+
- decoder_attention_mask (torch.Tensor): Mask for the decoder self-attention layers. Shape (B, M, M)
|
410 |
+
- mod_mask (torch.Tensor): An integer mask marking the modality type for each decoder token (with -1 indicating unassigned pad tokens). Shape (B, M)
|
411 |
+
"""
|
412 |
+
# decoder_mask and target_mask are equivalent; we rename it here to harmonize with forward_mask_encoder
|
413 |
+
decoder_tokens_all, emb_all, decoder_mask_all, target_ids_all, decoder_attention_mask_all, mod_mask_all = self.cat_decoder_tensors(mod_dict)
|
414 |
+
|
415 |
+
# Add arange multiplied by small constant to mask so they get sorted in a deterministic way
|
416 |
+
mask_arange = torch.arange(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6
|
417 |
+
ids_shuffle = torch.argsort(decoder_mask_all + mask_arange, dim=1)
|
418 |
+
# ids_restore = torch.argsort(ids_shuffle, dim=1)
|
419 |
+
ids_keep = ids_shuffle[:, :num_decoder_tokens]
|
420 |
+
|
421 |
+
decoder_tokens = torch.gather(decoder_tokens_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=decoder_tokens_all.shape[2]))
|
422 |
+
decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
|
423 |
+
decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep)
|
424 |
+
target_ids = torch.gather(target_ids_all, dim=1, index=ids_keep)
|
425 |
+
decoder_attention_mask = torch.gather(decoder_attention_mask_all, dim=1, index=ids_keep)
|
426 |
+
mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)
|
427 |
+
|
428 |
+
decoder_tokens[decoder_mask] = 0.
|
429 |
+
decoder_emb[decoder_mask] = 0.
|
430 |
+
target_ids[decoder_mask] = 0
|
431 |
+
decoder_attention_mask = self.adapt_decoder_attention_mask(decoder_attention_mask, mod_mask)
|
432 |
+
mod_mask[decoder_mask] = -1
|
433 |
+
|
434 |
+
# This means this mask can then be re-used for decoder cross-attention
|
435 |
+
decoder_mask = rearrange(decoder_mask, 'b n2 -> b 1 n2')
|
436 |
+
|
437 |
+
|
438 |
+
return decoder_tokens, decoder_emb, decoder_mask, target_ids, decoder_attention_mask, mod_mask
|
439 |
+
|
440 |
+
def adapt_decoder_attention_mask(self, decoder_attention_mask: torch.Tensor, mod_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
441 |
+
"""
|
442 |
+
Transforms the compressed decoder attention mask to a full attention mask based on the specified constraints.
|
443 |
+
|
444 |
+
Args:
|
445 |
+
decoder_attention_mask (torch.Tensor): Initial attention mask indicating attention constraints. Shape (B, M) where M is the number of the decoder tokens.
|
446 |
+
mod_mask (torch.Tensor, optional): Modality mask to separate attention masks per modality. Shape (B, M)
|
447 |
+
|
448 |
+
Returns:
|
449 |
+
torch.Tensor: Adapted attention mask. Shape (B, M, M) where M is the number of the decoder tokens.
|
450 |
+
"""
|
451 |
+
B, N = decoder_attention_mask.shape
|
452 |
+
|
453 |
+
if self.decoder_causal_mask:
|
454 |
+
# For causal mode, tokens can only attend to preceding tokens and themselves.
|
455 |
+
causal_mask = torch.ones((N, N), dtype=torch.bool, device=decoder_attention_mask.device).triu(1)
|
456 |
+
causal_mask = repeat(causal_mask, "n1 n2 -> b n1 n2", b=B)
|
457 |
+
adapted_attention_mask = causal_mask
|
458 |
+
else:
|
459 |
+
# Cumulatively sum the attention mask to determine token-wise attention behavior.
|
460 |
+
# Examples:
|
461 |
+
# Mask [4, 0, 0, 0] -> Cumsum: [4, 4, 4, 4] -> All tokens attend to each other.
|
462 |
+
# Mask [1, 1, 1, 1] -> Cumsum: [1, 2, 3, 4] -> Strict autoregressive behavior.
|
463 |
+
# Mask [2, 0, 1, 1] -> Cumsum: [2, 2, 3, 4] -> Tokens 1 and 2 attend to each other, token 3 attends to tokens 1-3, and token 4 to all.
|
464 |
+
attention_arange = torch.arange(N, device=decoder_attention_mask.device)
|
465 |
+
attention_arange = repeat(attention_arange, "n2 -> b n1 n2", b=B, n1=N)
|
466 |
+
cumsum_mask = torch.cumsum(decoder_attention_mask, dim=-1)
|
467 |
+
cumsum_mask = rearrange(cumsum_mask, "b n -> b n 1")
|
468 |
+
adapted_attention_mask = (attention_arange >= cumsum_mask)
|
469 |
+
|
470 |
+
if self.decoder_sep_mask:
|
471 |
+
# Separate attention between tokens based on their modality using mod_mask.
|
472 |
+
sep_mask = repeat(mod_mask, "b n2 -> b n1 n2", n1=N) != repeat(mod_mask, "b n1 -> b n1 n2", n2=N)
|
473 |
+
adapted_attention_mask = adapted_attention_mask | sep_mask
|
474 |
+
|
475 |
+
return adapted_attention_mask
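A quick standalone check of the cumsum construction described in the comments above, using the third example ([2, 0, 1, 1]); True marks positions that may not be attended to:

import torch
from einops import rearrange, repeat

# Compressed mask [2, 0, 1, 1]: tokens 1-2 form a block, tokens 3 and 4 are autoregressive.
decoder_attention_mask = torch.tensor([[2, 0, 1, 1]])
B, N = decoder_attention_mask.shape

attention_arange = repeat(torch.arange(N), "n2 -> b n1 n2", b=B, n1=N)
cumsum_mask = rearrange(torch.cumsum(decoder_attention_mask, dim=-1), "b n -> b n 1")
blocked = attention_arange >= cumsum_mask  # True = attention is NOT allowed

print(blocked.int()[0])
# tensor([[0, 0, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]])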
|
476 |
+
|
477 |
+
def forward_encoder(self,
|
478 |
+
x: torch.Tensor,
|
479 |
+
encoder_mask: torch.Tensor) -> torch.Tensor:
|
480 |
+
"""Forward pass for the encoder.
|
481 |
+
|
482 |
+
Args:
|
483 |
+
x (torch.Tensor): Encoder input tokens. Shape (B, N, D) where N is the number of encoder tokens.
|
484 |
+
encoder_mask (torch.Tensor): Encoder mask indicating which tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N)
|
485 |
+
|
486 |
+
Returns:
|
487 |
+
torch.Tensor: Encoder output. Shape (B, N, D)
|
488 |
+
"""
|
489 |
+
|
490 |
+
for blk in self.encoder:
|
491 |
+
x = blk(x, mask=encoder_mask)
|
492 |
+
|
493 |
+
x = self.encoder_norm(x)
|
494 |
+
|
495 |
+
return x
|
496 |
+
|
497 |
+
def forward_decoder(self,
|
498 |
+
y: torch.Tensor,
|
499 |
+
context: torch.Tensor,
|
500 |
+
encoder_mask: torch.Tensor,
|
501 |
+
decoder_attention_mask: torch.Tensor) -> torch.Tensor:
|
502 |
+
"""Forward pass for the decoder.
|
503 |
+
|
504 |
+
Args:
|
505 |
+
y (torch.Tensor): Decoder input tokens. Shape (B, M, D).
|
506 |
+
context (torch.Tensor): Context for the decoder (i.e. encoder output). Shape (B, N, D).
|
507 |
+
encoder_mask (torch.Tensor): Encoder mask indicating which tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N).
|
508 |
+
decoder_attention_mask (torch.Tensor): Decoder attention mask. Shape (B, M, M).
|
509 |
+
|
510 |
+
Returns:
|
511 |
+
torch.Tensor: Decoder output. Shape (B, M, D).
|
512 |
+
"""
|
513 |
+
|
514 |
+
for blk in self.decoder:
|
515 |
+
y = blk(y, context, sa_mask=decoder_attention_mask, xa_mask=encoder_mask)
|
516 |
+
|
517 |
+
y = self.decoder_norm(y)
|
518 |
+
|
519 |
+
return y
|
520 |
+
|
521 |
+
def forward_logits(self,
|
522 |
+
y: torch.Tensor,
|
523 |
+
decoder_mod_dict: Dict[str, Dict[str, torch.Tensor]],
|
524 |
+
decoder_mod_mask: torch.Tensor,
|
525 |
+
return_all_logits: bool = False) -> Dict[str, torch.Tensor]:
|
526 |
+
"""Forward computation of logits for each modality.
|
527 |
+
|
528 |
+
Args:
|
529 |
+
y (torch.Tensor): Decoder output. Shape (B, M, D).
|
530 |
+
decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder.
|
531 |
+
decoder_mod_mask (torch.Tensor): Integer mask indicating which tokens belong to which modality. Shape (B, M).
|
532 |
+
|
533 |
+
Returns:
|
534 |
+
Dict[str, torch.Tensor]: Dictionary of logits for each modality.
|
535 |
+
"""
|
536 |
+
|
537 |
+
mod_logits = {}
|
538 |
+
for mod, d in decoder_mod_dict.items():
|
539 |
+
idx = self.modality_info[mod]["id"]
|
540 |
+
if return_all_logits:
|
541 |
+
logits = self.decoder_embeddings[mod].forward_logits(y)
|
542 |
+
else:
|
543 |
+
logits = self.decoder_embeddings[mod].forward_logits(y[decoder_mod_mask == idx])
|
544 |
+
mod_logits[mod] = logits
|
545 |
+
return mod_logits
|
546 |
+
|
547 |
+
def forward_loss(self,
|
548 |
+
y: torch.Tensor,
|
549 |
+
target_ids: torch.Tensor,
|
550 |
+
decoder_mod_dict: Dict[str, Any],
|
551 |
+
decoder_mod_mask: torch.Tensor, loss_type: str) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
552 |
+
"""Computes the loss based on the specified loss type.
|
553 |
+
|
554 |
+
Args:
|
555 |
+
y (torch.Tensor): Decoder output. Shape (B, M, D).
|
556 |
+
target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M).
|
557 |
+
decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder.
|
558 |
+
decoder_mod_mask (torch.Tensor): Integer mask indicating which tokens belong to which modality. Shape (B, M).
|
559 |
+
loss_type (str): The type of loss to compute. Either 'mod' or 'token'.
|
560 |
+
|
561 |
+
Returns:
|
562 |
+
Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total loss and dictionary of loss for each modality.
|
563 |
+
"""
|
564 |
+
if loss_type in ['mod', 'modality']:
|
565 |
+
loss, mod_loss = self.forward_mod_loss(y, target_ids, decoder_mod_dict, decoder_mod_mask)
|
566 |
+
elif loss_type == 'token':
|
567 |
+
loss, mod_loss = self.forward_token_loss(y, target_ids, decoder_mod_dict, decoder_mod_mask)
|
568 |
+
else:
|
569 |
+
raise ValueError("Invalid loss type")
|
570 |
+
|
571 |
+
return loss, mod_loss
|
572 |
+
|
573 |
+
def forward_mod_loss(self,
|
574 |
+
y: torch.Tensor,
|
575 |
+
target_ids: torch.Tensor,
|
576 |
+
decoder_mod_dict: Dict[str, Any],
|
577 |
+
decoder_mod_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
578 |
+
"""Computes the modality-wise loss.
|
579 |
+
|
580 |
+
Args:
|
581 |
+
y (torch.Tensor): Decoder tokens. Shape (B, M, D).
|
582 |
+
target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M).
|
583 |
+
decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder.
|
584 |
+
decoder_mod_mask (torch.Tensor): Mask indicating which tokens belong to which modality. Shape (B, M).
|
585 |
+
|
586 |
+
Returns:
|
587 |
+
Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total modality loss and dictionary of loss for each modality.
|
588 |
+
"""
|
589 |
+
mod_loss = {}
|
590 |
+
for mod, d in decoder_mod_dict.items():
|
591 |
+
idx = self.modality_info[mod]["id"]
|
592 |
+
logits = self.decoder_embeddings[mod].forward_logits(y[decoder_mod_mask == idx])
|
593 |
+
if logits.numel() == 0:
|
594 |
+
# If there are no logits / targets, set mod_loss to 0
|
595 |
+
mod_loss[mod] = torch.zeros(1, device=logits.device)
|
596 |
+
else:
|
597 |
+
loss = F.cross_entropy(logits, target_ids[decoder_mod_mask == idx].long(), reduction='mean')
|
598 |
+
mod_loss[mod] = loss
|
599 |
+
|
600 |
+
loss = sum(mod_loss.values()) / len(mod_loss)
|
601 |
+
|
602 |
+
return loss, mod_loss
|
603 |
+
|
604 |
+
def forward_token_loss(self,
|
605 |
+
y: torch.Tensor,
|
606 |
+
target_ids: torch.Tensor,
|
607 |
+
decoder_mod_dict: Dict[str, Any],
|
608 |
+
decoder_mod_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
609 |
+
"""Computes the token-wise loss.
|
610 |
+
|
611 |
+
Args:
|
612 |
+
y (torch.Tensor): Decoder tokens. Shape (B, M, D).
|
613 |
+
target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M).
|
614 |
+
decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder.
|
615 |
+
decoder_mod_mask (torch.Tensor): Mask indicating which tokens belong to which modality. Shape (B, M).
|
616 |
+
|
617 |
+
Returns:
|
618 |
+
Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total token loss and dictionary of loss for each modality.
|
619 |
+
"""
|
620 |
+
mod_loss = {}
|
621 |
+
mod_count = {}
|
622 |
+
|
623 |
+
for mod, d in decoder_mod_dict.items():
|
624 |
+
idx = self.modality_info[mod]["id"]
|
625 |
+
logits = self.decoder_embeddings[mod].forward_logits(y[decoder_mod_mask == idx])
|
626 |
+
if logits.numel() == 0:
|
627 |
+
# If there are no logits / targets, set mod_loss to 0
|
628 |
+
mod_loss[mod] = torch.zeros(1, device=logits.device)
|
629 |
+
mod_count[mod] = 0
|
630 |
+
else:
|
631 |
+
loss = F.cross_entropy(logits, target_ids[decoder_mod_mask == idx].long(), reduction='mean')
|
632 |
+
mod_loss[mod] = loss
|
633 |
+
mod_count[mod] = logits.numel()
|
634 |
+
|
635 |
+
loss = sum([mod_loss[mod] * mod_count[mod] for mod in mod_loss.keys()]) / sum(mod_count.values())
|
636 |
+
|
637 |
+
return loss, mod_loss
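To illustrate the difference between the two reductions with made-up numbers: two modalities with mean losses 2.0 and 1.0 and counts 10 and 90 give 1.5 under 'mod' averaging but 1.1 under the count-weighted 'token' averaging.

mod_loss = {'a': 2.0, 'b': 1.0}  # made-up per-modality mean losses
mod_count = {'a': 10, 'b': 90}   # made-up counts (the code above takes them from logits.numel())
loss_mod = sum(mod_loss.values()) / len(mod_loss)                                         # 1.5
loss_token = sum(mod_loss[m] * mod_count[m] for m in mod_loss) / sum(mod_count.values())  # 1.1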
|
638 |
+
|
639 |
+
|
640 |
+
def forward(self,
|
641 |
+
mod_dict: Dict[str, Dict[str, torch.Tensor]],
|
642 |
+
num_encoder_tokens: int,
|
643 |
+
num_decoder_tokens: int,
|
644 |
+
loss_type: str = 'mod',
|
645 |
+
return_logits: bool = False) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
|
646 |
+
"""
|
647 |
+
Forward pass for the model.
|
648 |
+
|
649 |
+
Args:
|
650 |
+
mod_dict (Dict[str, Dict[str, torch.Tensor]]): Dictionary containing the tensors, masks, and other info for each modality.
|
651 |
+
- mod_dict[modality_name]["tensor_name"]: Shape can vary based on tensor_name and modality.
|
652 |
+
num_encoder_tokens (int): Number of tokens to keep for the encoder.
|
653 |
+
num_decoder_tokens (int): Number of tokens to keep for the decoder.
|
654 |
+
loss_type (str, optional): The type of loss to compute. Can be 'mod' (average of loss per modality) or 'token' (average loss per token). Default is 'mod'.
|
655 |
+
return_logits (bool, optional): If True, return the logits. Default is False.
|
656 |
+
|
657 |
+
Returns:
|
658 |
+
Union[dict, tuple]:
|
659 |
+
- If return_logits is True: Dictionary of logits for each modality.
|
660 |
+
- Otherwise: Tuple containing the total loss and dictionary of loss for each modality.
|
661 |
+
"""
|
662 |
+
|
663 |
+
# Mod dicts
|
664 |
+
encoder_mod_dict = {mod: self.encoder_embeddings[mod](d)
|
665 |
+
for mod, d in mod_dict.items()
|
666 |
+
if mod in self.encoder_embeddings}
|
667 |
+
encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder(encoder_mod_dict, num_encoder_tokens)
|
668 |
+
|
669 |
+
decoder_mod_dict = {mod: self.decoder_embeddings[mod].forward_embed(d)
|
670 |
+
for mod, d in mod_dict.items()
|
671 |
+
if mod in self.decoder_embeddings}
|
672 |
+
decoder_tokens, decoder_emb, decoder_mask, target_ids, decoder_attention_mask, decoder_mod_mask = self.forward_mask_decoder(decoder_mod_dict, num_decoder_tokens)
|
673 |
+
|
674 |
+
# Encoder
|
675 |
+
x = encoder_tokens + encoder_emb
|
676 |
+
x = self.forward_encoder(x, encoder_mask=encoder_mask)
|
677 |
+
|
678 |
+
# Decoder
|
679 |
+
context = self.decoder_proj_context(x) + encoder_emb
|
680 |
+
y = decoder_tokens + decoder_emb
|
681 |
+
y = self.forward_decoder(y, context, encoder_mask=encoder_mask, decoder_attention_mask=decoder_attention_mask)
|
682 |
+
|
683 |
+
# Logits
|
684 |
+
if return_logits:
|
685 |
+
mod_logits = self.forward_logits(y, decoder_mod_dict, decoder_mod_mask, return_all_logits=True)
|
686 |
+
return mod_logits
|
687 |
+
|
688 |
+
# Loss
|
689 |
+
loss, mod_loss = self.forward_loss(y, target_ids, decoder_mod_dict, decoder_mod_mask, loss_type)
|
690 |
+
|
691 |
+
return loss, mod_loss
|
692 |
+
|
693 |
+
|
694 |
+
def freeze_encoder(self, freeze_embeddings=True):
|
695 |
+
for param in self.encoder.parameters():
|
696 |
+
param.requires_grad = False
|
697 |
+
|
698 |
+
for param in self.encoder_norm.parameters():
|
699 |
+
param.requires_grad = False
|
700 |
+
|
701 |
+
if freeze_embeddings:
|
702 |
+
for param in self.encoder_embeddings.parameters():
|
703 |
+
param.requires_grad = False
|
704 |
+
|
705 |
+
def freeze_encoder_except_specific_embeddings(self, frozen_embedding_domain):
|
706 |
+
frozen_embedding_domain = frozen_embedding_domain.split('-')
|
707 |
+
for param in self.encoder.parameters():
|
708 |
+
param.requires_grad = False
|
709 |
+
|
710 |
+
for param in self.encoder_norm.parameters():
|
711 |
+
param.requires_grad = False
|
712 |
+
|
713 |
+
for name, param in self.encoder_embeddings.named_parameters():
|
714 |
+
if name.split('.')[0] in frozen_embedding_domain:
|
715 |
+
param.requires_grad = False
|
716 |
+
|
717 |
+
def unfreeze_encoder(self, unfreeze_embeddings=True):
|
718 |
+
for param in self.encoder.parameters():
|
719 |
+
param.requires_grad = True
|
720 |
+
|
721 |
+
for param in self.encoder_norm.parameters():
|
722 |
+
param.requires_grad = True
|
723 |
+
|
724 |
+
if unfreeze_embeddings:
|
725 |
+
for param in self.encoder_embeddings.parameters():
|
726 |
+
param.requires_grad = True
|
727 |
+
|
728 |
+
def freeze_decoder(self, freeze_embeddings=True):
|
729 |
+
for param in self.decoder.parameters():
|
730 |
+
param.requires_grad = False
|
731 |
+
|
732 |
+
for param in self.decoder_norm.parameters():
|
733 |
+
param.requires_grad = False
|
734 |
+
|
735 |
+
if freeze_embeddings:
|
736 |
+
for param in self.decoder_embeddings.parameters():
|
737 |
+
param.requires_grad = False
|
738 |
+
|
739 |
+
def freeze_decoder_except_specific_embeddings(self, frozen_embedding_domain):
|
740 |
+
frozen_embedding_domain = frozen_embedding_domain.split('-')
|
741 |
+
for param in self.decoder.parameters():
|
742 |
+
param.requires_grad = False
|
743 |
+
|
744 |
+
for param in self.decoder_norm.parameters():
|
745 |
+
param.requires_grad = False
|
746 |
+
|
747 |
+
for name, param in self.decoder_embeddings.named_parameters():
|
748 |
+
if name.split('.')[0] in frozen_embedding_domain:
|
749 |
+
param.requires_grad = False
|
750 |
+
|
751 |
+
def unfreeze_decoder(self, unfreeze_embeddings=True):
|
752 |
+
for param in self.decoder.parameters():
|
753 |
+
param.requires_grad = True
|
754 |
+
|
755 |
+
for param in self.decoder_norm.parameters():
|
756 |
+
param.requires_grad = True
|
757 |
+
|
758 |
+
if unfreeze_embeddings:
|
759 |
+
for param in self.decoder_embeddings.parameters():
|
760 |
+
param.requires_grad = True
|
761 |
+
|
762 |
+
def freeze_shared_params(self):
|
763 |
+
self.freeze_encoder(freeze_embeddings=False)
|
764 |
+
self.freeze_decoder(freeze_embeddings=False)
|
765 |
+
|
766 |
+
def freeze_params_except_specific_embeddings(self, frozen_embedding_domain):
|
767 |
+
self.freeze_encoder_except_specific_embeddings(frozen_embedding_domain=frozen_embedding_domain)
|
768 |
+
self.freeze_decoder_except_specific_embeddings(frozen_embedding_domain=frozen_embedding_domain)
|
769 |
+
|
770 |
+
def unfreeze_shared_params(self):
|
771 |
+
self.unfreeze_encoder(unfreeze_embeddings=False)
|
772 |
+
self.unfreeze_decoder(unfreeze_embeddings=False)
|
773 |
+
|
774 |
+
def unfreeze_all(self):
|
775 |
+
self.unfreeze_encoder(unfreeze_embeddings=True)
|
776 |
+
self.unfreeze_decoder(unfreeze_embeddings=True)
|
777 |
+
|
778 |
+
|
779 |
+
################################################
|
780 |
+
|
781 |
+
# Wrapper for easy loading with Huggingface Hub
|
782 |
+
|
783 |
+
class FM(FourM, PyTorchModelHubMixin):
|
784 |
+
"""Wrapper around FourM for easy loading with Huggingface Hub.
|
785 |
+
|
786 |
+
Args:
|
787 |
+
config (dict): Dictionary containing the model and modality configuration,
|
788 |
+
used for loading from Huggingface Hub.
|
789 |
+
"""
|
790 |
+
def __init__(self, config: dict):
|
791 |
+
|
792 |
+
config = copy.deepcopy(config)
|
793 |
+
|
794 |
+
all_domains = sorted(list(set(config['domains_in']) | set(config['domains_out'])))
|
795 |
+
modality_info = {mod: MODALITY_INFO[mod] for mod in all_domains}
|
796 |
+
|
797 |
+
encoder_embeddings = {}
|
798 |
+
for mod in config['domains_in']:
|
799 |
+
info = modality_info[mod]
|
800 |
+
if info.get("encoder_embedding", None) is not None:
|
801 |
+
if info["type"] == "img":
|
802 |
+
image_size, patch_size = info.get('input_size', config['image_size']), info.get('patch_size', config['patch_size'])
|
803 |
+
encoder_embeddings[mod] = info["encoder_embedding"](patch_size=patch_size, image_size=image_size)
|
804 |
+
else:
|
805 |
+
encoder_embeddings[mod] = info["encoder_embedding"]()
|
806 |
+
|
807 |
+
decoder_embeddings = {}
|
808 |
+
for mod in config['domains_out']:
|
809 |
+
info = modality_info[mod]
|
810 |
+
if info.get("decoder_embedding", None) is not None:
|
811 |
+
if info["type"] == "img":
|
812 |
+
image_size, patch_size = info.get('input_size', config['image_size']), info.get('patch_size', config['patch_size'])
|
813 |
+
decoder_embeddings[mod] = info["decoder_embedding"](patch_size=patch_size, image_size=image_size, share_embedding=False)
|
814 |
+
else:
|
815 |
+
decoder_embeddings[mod] = info["decoder_embedding"](share_embedding=False)
|
816 |
+
|
817 |
+
config['norm_layer'] = partial(LayerNorm, eps=1e-6, bias=config['norm_bias'])
|
818 |
+
config['act_layer'] = getattr(torch.nn, config['act_layer'])
|
819 |
+
|
820 |
+
del config['norm_bias']
|
821 |
+
del config['domains_in']
|
822 |
+
del config['domains_out']
|
823 |
+
del config['image_size']
|
824 |
+
del config['patch_size']
|
825 |
+
|
826 |
+
super().__init__(
|
827 |
+
encoder_embeddings=encoder_embeddings,
|
828 |
+
decoder_embeddings=decoder_embeddings,
|
829 |
+
modality_info=modality_info,
|
830 |
+
**config
|
831 |
+
)
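Because FM mixes in PyTorchModelHubMixin, checkpoints published on the Hugging Face Hub can be loaded in one call. A minimal sketch; the repository id is only an example and should be replaced by an actual 4M checkpoint:

from fourm.models.fm import FM

# Example repo id (substitute a real 4M checkpoint on the Hub).
model = FM.from_pretrained('EPFL-VILAB/4M-7_B_CC12M').eval()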
|
832 |
+
|
833 |
+
|
834 |
+
################################################
|
835 |
+
|
836 |
+
# Model definitions
|
837 |
+
|
838 |
+
# GELU variants
|
839 |
+
@register_model
|
840 |
+
def fm_tiny_6e_6d_gelu(
|
841 |
+
encoder_embeddings: Dict[str, nn.Module],
|
842 |
+
decoder_embeddings: Dict[str, nn.Module],
|
843 |
+
**kwargs):
|
844 |
+
model = FourM(
|
845 |
+
encoder_embeddings=encoder_embeddings,
|
846 |
+
decoder_embeddings=decoder_embeddings,
|
847 |
+
encoder_depth=6,
|
848 |
+
decoder_depth=6,
|
849 |
+
dim=384,
|
850 |
+
num_heads=6,
|
851 |
+
mlp_ratio=4,
|
852 |
+
qkv_bias=True,
|
853 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
854 |
+
**kwargs
|
855 |
+
)
|
856 |
+
return model
|
857 |
+
|
858 |
+
|
859 |
+
@register_model
|
860 |
+
def fm_small_8e_8d_gelu(
|
861 |
+
encoder_embeddings: Dict[str, nn.Module],
|
862 |
+
decoder_embeddings: Dict[str, nn.Module],
|
863 |
+
**kwargs):
|
864 |
+
model = FourM(
|
865 |
+
encoder_embeddings=encoder_embeddings,
|
866 |
+
decoder_embeddings=decoder_embeddings,
|
867 |
+
encoder_depth=8,
|
868 |
+
decoder_depth=8,
|
869 |
+
dim=512,
|
870 |
+
num_heads=8,
|
871 |
+
mlp_ratio=4,
|
872 |
+
qkv_bias=True,
|
873 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
874 |
+
**kwargs
|
875 |
+
)
|
876 |
+
return model
|
877 |
+
|
878 |
+
|
879 |
+
@register_model
|
880 |
+
def fm_base_12e_12d_gelu(
|
881 |
+
encoder_embeddings: Dict[str, nn.Module],
|
882 |
+
decoder_embeddings: Dict[str, nn.Module],
|
883 |
+
**kwargs):
|
884 |
+
model = FourM(
|
885 |
+
encoder_embeddings=encoder_embeddings,
|
886 |
+
decoder_embeddings=decoder_embeddings,
|
887 |
+
encoder_depth=12,
|
888 |
+
decoder_depth=12,
|
889 |
+
dim=768,
|
890 |
+
num_heads=12,
|
891 |
+
mlp_ratio=4,
|
892 |
+
qkv_bias=True,
|
893 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
894 |
+
**kwargs
|
895 |
+
)
|
896 |
+
return model
|
897 |
+
|
898 |
+
|
899 |
+
@register_model
|
900 |
+
def fm_large_24e_24d_gelu(
|
901 |
+
encoder_embeddings: Dict[str, nn.Module],
|
902 |
+
decoder_embeddings: Dict[str, nn.Module],
|
903 |
+
**kwargs):
|
904 |
+
model = FourM(
|
905 |
+
encoder_embeddings=encoder_embeddings,
|
906 |
+
decoder_embeddings=decoder_embeddings,
|
907 |
+
encoder_depth=24,
|
908 |
+
decoder_depth=24,
|
909 |
+
dim=1024,
|
910 |
+
num_heads=16,
|
911 |
+
mlp_ratio=4,
|
912 |
+
qkv_bias=True,
|
913 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
914 |
+
**kwargs
|
915 |
+
)
|
916 |
+
return model
|
917 |
+
|
918 |
+
@register_model
|
919 |
+
def fm_xlarge_24e_24d_gelu(
|
920 |
+
encoder_embeddings: Dict[str, nn.Module],
|
921 |
+
decoder_embeddings: Dict[str, nn.Module],
|
922 |
+
**kwargs):
|
923 |
+
model = FourM(
|
924 |
+
encoder_embeddings=encoder_embeddings,
|
925 |
+
decoder_embeddings=decoder_embeddings,
|
926 |
+
encoder_depth=24,
|
927 |
+
decoder_depth=24,
|
928 |
+
dim=2048,
|
929 |
+
num_heads=32,
|
930 |
+
mlp_ratio=4,
|
931 |
+
qkv_bias=True,
|
932 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
933 |
+
**kwargs
|
934 |
+
)
|
935 |
+
return model
|
936 |
+
|
937 |
+
|
938 |
+
# SwiGLU variants
|
939 |
+
@register_model
|
940 |
+
def fm_tiny_6e_6d_swiglu_nobias(
|
941 |
+
encoder_embeddings: Dict[str, nn.Module],
|
942 |
+
decoder_embeddings: Dict[str, nn.Module],
|
943 |
+
**kwargs):
|
944 |
+
model = FourM(
|
945 |
+
encoder_embeddings=encoder_embeddings,
|
946 |
+
decoder_embeddings=decoder_embeddings,
|
947 |
+
encoder_depth=6,
|
948 |
+
decoder_depth=6,
|
949 |
+
dim=384,
|
950 |
+
num_heads=6,
|
951 |
+
mlp_ratio=4,
|
952 |
+
qkv_bias=False,
|
953 |
+
proj_bias=False,
|
954 |
+
mlp_bias=False,
|
955 |
+
norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
|
956 |
+
act_layer=nn.SiLU,
|
957 |
+
gated_mlp=True,
|
958 |
+
**kwargs
|
959 |
+
)
|
960 |
+
return model
|
961 |
+
|
962 |
+
|
963 |
+
@register_model
|
964 |
+
def fm_small_8e_8d_swiglu_nobias(
|
965 |
+
encoder_embeddings: Dict[str, nn.Module],
|
966 |
+
decoder_embeddings: Dict[str, nn.Module],
|
967 |
+
**kwargs):
|
968 |
+
model = FourM(
|
969 |
+
encoder_embeddings=encoder_embeddings,
|
970 |
+
decoder_embeddings=decoder_embeddings,
|
971 |
+
encoder_depth=8,
|
972 |
+
decoder_depth=8,
|
973 |
+
dim=512,
|
974 |
+
num_heads=8,
|
975 |
+
mlp_ratio=4,
|
976 |
+
qkv_bias=False,
|
977 |
+
proj_bias=False,
|
978 |
+
mlp_bias=False,
|
979 |
+
norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
|
980 |
+
act_layer=nn.SiLU,
|
981 |
+
gated_mlp=True,
|
982 |
+
**kwargs
|
983 |
+
)
|
984 |
+
return model
|
985 |
+
|
986 |
+
|
987 |
+
@register_model
|
988 |
+
def fm_base_12e_12d_swiglu_nobias(
|
989 |
+
encoder_embeddings: Dict[str, nn.Module],
|
990 |
+
decoder_embeddings: Dict[str, nn.Module],
|
991 |
+
**kwargs):
|
992 |
+
model = FourM(
|
993 |
+
encoder_embeddings=encoder_embeddings,
|
994 |
+
decoder_embeddings=decoder_embeddings,
|
995 |
+
encoder_depth=12,
|
996 |
+
decoder_depth=12,
|
997 |
+
dim=768,
|
998 |
+
num_heads=12,
|
999 |
+
mlp_ratio=4,
|
1000 |
+
qkv_bias=False,
|
1001 |
+
proj_bias=False,
|
1002 |
+
mlp_bias=False,
|
1003 |
+
norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
|
1004 |
+
act_layer=nn.SiLU,
|
1005 |
+
gated_mlp=True,
|
1006 |
+
**kwargs
|
1007 |
+
)
|
1008 |
+
return model
|
1009 |
+
|
1010 |
+
@register_model
|
1011 |
+
def fm_large_24e_24d_swiglu_nobias(
|
1012 |
+
encoder_embeddings: Dict[str, nn.Module],
|
1013 |
+
decoder_embeddings: Dict[str, nn.Module],
|
1014 |
+
**kwargs):
|
1015 |
+
model = FourM(
|
1016 |
+
encoder_embeddings=encoder_embeddings,
|
1017 |
+
decoder_embeddings=decoder_embeddings,
|
1018 |
+
encoder_depth=24,
|
1019 |
+
decoder_depth=24,
|
1020 |
+
dim=1024,
|
1021 |
+
num_heads=16,
|
1022 |
+
mlp_ratio=4,
|
1023 |
+
qkv_bias=False,
|
1024 |
+
proj_bias=False,
|
1025 |
+
mlp_bias=False,
|
1026 |
+
norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
|
1027 |
+
act_layer=nn.SiLU,
|
1028 |
+
gated_mlp=True,
|
1029 |
+
**kwargs
|
1030 |
+
)
|
1031 |
+
return model
|
1032 |
+
|
1033 |
+
@register_model
|
1034 |
+
def fm_xlarge_24e_24d_swiglu_nobias(
|
1035 |
+
encoder_embeddings: Dict[str, nn.Module],
|
1036 |
+
decoder_embeddings: Dict[str, nn.Module],
|
1037 |
+
**kwargs):
|
1038 |
+
model = FourM(
|
1039 |
+
encoder_embeddings=encoder_embeddings,
|
1040 |
+
decoder_embeddings=decoder_embeddings,
|
1041 |
+
encoder_depth=24,
|
1042 |
+
decoder_depth=24,
|
1043 |
+
dim=2048,
|
1044 |
+
num_heads=32,
|
1045 |
+
mlp_ratio=4,
|
1046 |
+
qkv_bias=False,
|
1047 |
+
proj_bias=False,
|
1048 |
+
mlp_bias=False,
|
1049 |
+
norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
|
1050 |
+
act_layer=nn.SiLU,
|
1051 |
+
gated_mlp=True,
|
1052 |
+
**kwargs
|
1053 |
+
)
|
1054 |
+
return model
|
1055 |
+
|
1056 |
+
# SwiGLU + QKNorm variants
|
1057 |
+
|
1058 |
+
|
1059 |
+
@register_model
|
1060 |
+
def fm_base_12e_12d_swiglu_qknorm_nobias(
|
1061 |
+
encoder_embeddings: Dict[str, nn.Module],
|
1062 |
+
decoder_embeddings: Dict[str, nn.Module],
|
1063 |
+
**kwargs):
|
1064 |
+
model = FourM(
|
1065 |
+
encoder_embeddings=encoder_embeddings,
|
1066 |
+
decoder_embeddings=decoder_embeddings,
|
1067 |
+
encoder_depth=12,
|
1068 |
+
decoder_depth=12,
|
1069 |
+
dim=768,
|
1070 |
+
num_heads=12,
|
1071 |
+
mlp_ratio=4,
|
1072 |
+
qkv_bias=False,
|
1073 |
+
proj_bias=False,
|
1074 |
+
mlp_bias=False,
|
1075 |
+
norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
|
1076 |
+
act_layer=nn.SiLU,
|
1077 |
+
gated_mlp=True,
|
1078 |
+
qk_norm=True,
|
1079 |
+
**kwargs
|
1080 |
+
)
|
1081 |
+
return model
|
1082 |
+
|
1083 |
+
|
1084 |
+
@register_model
|
1085 |
+
def fm_large_24e_24d_swiglu_qknorm_nobias(
|
1086 |
+
encoder_embeddings: Dict[str, nn.Module],
|
1087 |
+
decoder_embeddings: Dict[str, nn.Module],
|
1088 |
+
**kwargs):
|
1089 |
+
model = FourM(
|
1090 |
+
encoder_embeddings=encoder_embeddings,
|
1091 |
+
decoder_embeddings=decoder_embeddings,
|
1092 |
+
encoder_depth=24,
|
1093 |
+
decoder_depth=24,
|
1094 |
+
dim=1024,
|
1095 |
+
num_heads=16,
|
1096 |
+
mlp_ratio=4,
|
1097 |
+
qkv_bias=False,
|
1098 |
+
proj_bias=False,
|
1099 |
+
mlp_bias=False,
|
1100 |
+
norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
|
1101 |
+
act_layer=nn.SiLU,
|
1102 |
+
gated_mlp=True,
|
1103 |
+
qk_norm=True,
|
1104 |
+
**kwargs
|
1105 |
+
)
|
1106 |
+
return model
|
1107 |
+
|
1108 |
+
@register_model
|
1109 |
+
def fm_xlarge_24e_24d_swiglu_qknorm_nobias(
|
1110 |
+
encoder_embeddings: Dict[str, nn.Module],
|
1111 |
+
decoder_embeddings: Dict[str, nn.Module],
|
1112 |
+
**kwargs):
|
1113 |
+
model = FourM(
|
1114 |
+
encoder_embeddings=encoder_embeddings,
|
1115 |
+
decoder_embeddings=decoder_embeddings,
|
1116 |
+
encoder_depth=24,
|
1117 |
+
decoder_depth=24,
|
1118 |
+
dim=2048,
|
1119 |
+
num_heads=32,
|
1120 |
+
mlp_ratio=4,
|
1121 |
+
qkv_bias=False,
|
1122 |
+
proj_bias=False,
|
1123 |
+
mlp_bias=False,
|
1124 |
+
norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
|
1125 |
+
act_layer=nn.SiLU,
|
1126 |
+
gated_mlp=True,
|
1127 |
+
qk_norm=True,
|
1128 |
+
**kwargs
|
1129 |
+
)
|
1130 |
+
return model
|
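For a rough sanity check of the trunk sizes registered above, the constructors can also be called directly. A sketch under the assumption that passing empty embedding dicts is acceptable when one only wants to count encoder/decoder parameters (no forward pass is possible without modality embeddings):

from fourm.models.fm import fm_base_12e_12d_swiglu_nobias

model = fm_base_12e_12d_swiglu_nobias(
    encoder_embeddings={}, decoder_embeddings={}, modality_info={})
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M trunk parameters")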
fourm/models/fm_utils.py
ADDED
@@ -0,0 +1,387 @@
|
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# --------------------------------------------------------
|
15 |
+
# Some functions are based on the timm code base
|
16 |
+
# https://github.com/huggingface/pytorch-image-models
|
17 |
+
# --------------------------------------------------------
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
from einops import rearrange
|
23 |
+
|
24 |
+
|
25 |
+
def pair(t):
|
26 |
+
return t if isinstance(t, tuple) else (t, t)
|
27 |
+
|
28 |
+
def softmax1(tensor):
|
29 |
+
# See https://www.evanmiller.org/attention-is-off-by-one.html
|
30 |
+
return F.pad(tensor, (0,1)).softmax(dim=-1)[...,:-1]
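Unlike a regular softmax, softmax1 implicitly appends a zero logit before normalizing, so attention weights can sum to less than one and approach zero when all logits are very negative. A tiny illustration (uses the softmax1 defined above):

import torch

logits = torch.tensor([[0.0, 0.0, 0.0],
                       [-1e4, -1e4, -1e4]])
print(torch.softmax(logits, dim=-1))  # each row sums to 1
print(softmax1(logits))               # first row sums to 0.75, second row is ~0 everywhere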
|
31 |
+
|
32 |
+
def build_1d_sincos_posemb(max_len, embed_dim=1024, temperature=10000.):
|
33 |
+
"""Sine-cosine positional embeddings from MoCo-v3, adapted back to 1d
|
34 |
+
|
35 |
+
Returns positional embedding of shape (1, N, D)
|
36 |
+
"""
|
37 |
+
arange = torch.arange(max_len, dtype=torch.float32) # Shape (N,)
|
38 |
+
assert embed_dim % 2 == 0, 'Embed dimension must be divisible by 2 for 1D sin-cos position embedding'
|
39 |
+
pos_dim = embed_dim // 2
|
40 |
+
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim # Shape (D/2,)
|
41 |
+
omega = 1. / (temperature ** omega)
|
42 |
+
out = torch.einsum('n,d->nd', [arange, omega]) # Outer product, shape (N, D/2)
|
43 |
+
pos_emb = torch.cat([torch.sin(out), torch.cos(out)], dim=1).unsqueeze(0) # Shape (1, N, D)
|
44 |
+
return pos_emb
|
45 |
+
|
46 |
+
def build_2d_sincos_posemb(h, w, embed_dim=1024, temperature=10000.0):
|
47 |
+
"""Sine-cosine positional embeddings as used in MoCo-v3
|
48 |
+
|
49 |
+
Returns positional embedding of shape (1, N, D) where N = W*H
|
50 |
+
"""
|
51 |
+
grid_w = torch.arange(w, dtype=torch.float32) # Shape (W,)
|
52 |
+
grid_h = torch.arange(h, dtype=torch.float32) # Shape (H, )
|
53 |
+
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij') # Shapes (W, H)
|
54 |
+
assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
|
55 |
+
pos_dim = embed_dim // 4
|
56 |
+
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim # Shape (D/4,)
|
57 |
+
omega = 1. / (temperature ** omega)
|
58 |
+
out_w = torch.einsum('n,d->nd', [grid_w.reshape(-1), omega]) # Outer product, shape (W*H, D/4)
|
59 |
+
out_h = torch.einsum('n,d->nd', [grid_h.reshape(-1), omega]) # Outer product, shape (W*H, D/4)
|
60 |
+
pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1).unsqueeze(0) # Shape (1, W*H, D)
|
61 |
+
return pos_emb
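As an illustrative usage (not part of this file): a 224x224 image tokenized into 16x16 patches yields a 14x14 grid, so the embedding below has shape (1, 196, 768) and is typically registered as a fixed, non-learned buffer:

pos_emb = build_2d_sincos_posemb(h=14, w=14, embed_dim=768)
print(pos_emb.shape)  # torch.Size([1, 196, 768])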
|
62 |
+
|
63 |
+
|
64 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
65 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
66 |
+
Implementation from timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
|
67 |
+
"""
|
68 |
+
if drop_prob == 0. or not training:
|
69 |
+
return x
|
70 |
+
keep_prob = 1 - drop_prob
|
71 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
72 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
73 |
+
random_tensor.floor_() # binarize
|
74 |
+
output = x.div(keep_prob) * random_tensor
|
75 |
+
return output
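The division by keep_prob keeps the expected activation unchanged: each sample in the batch is either zeroed entirely or scaled by 1/keep_prob. A quick empirical check (illustrative only, uses the drop_path defined above):

import torch

x = torch.ones(10_000, 8)
out = drop_path(x, drop_prob=0.2, training=True)
print((out == 0).all(dim=1).float().mean())  # ~0.2: fraction of fully dropped samples
print(out.mean())                            # ~1.0: expectation is preserved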
|
76 |
+
|
77 |
+
|
78 |
+
class DropPath(nn.Module):
|
79 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(self, drop_prob=None):
|
83 |
+
super(DropPath, self).__init__()
|
84 |
+
self.drop_prob = drop_prob
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
return drop_path(x, self.drop_prob, self.training)
|
88 |
+
|
89 |
+
def extra_repr(self) -> str:
|
90 |
+
return 'p={}'.format(self.drop_prob)
|
91 |
+
|
92 |
+
|
93 |
+
class LayerNorm(nn.Module):
|
94 |
+
"""Custom implementation of LayerNorm with the option to disable the bias term"""
|
95 |
+
def __init__(self, normalized_shape: int, eps=1e-5, bias=True):
|
96 |
+
super().__init__()
|
97 |
+
self.eps = eps
|
98 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
99 |
+
if bias:
|
100 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
101 |
+
else:
|
102 |
+
self.register_buffer("bias", torch.zeros(normalized_shape))
|
103 |
+
|
104 |
+
# Normalized shape must be a tuple for F.layer_norm
|
105 |
+
self.normalized_shape = (normalized_shape,)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, eps=self.eps)
|
109 |
+
|
110 |
+
|
111 |
+
class Mlp(nn.Module):
|
112 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., bias=True):
|
113 |
+
super().__init__()
|
114 |
+
out_features = out_features or in_features
|
115 |
+
hidden_features = hidden_features or in_features
|
116 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
117 |
+
self.act = act_layer()
|
118 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
119 |
+
self.drop = nn.Dropout(drop)
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
x = self.fc1(x)
|
123 |
+
x = self.act(x)
|
124 |
+
x = self.fc2(x)
|
125 |
+
x = self.drop(x)
|
126 |
+
return x
|
127 |
+
|
128 |
+
|
129 |
+
class GatedMlp(nn.Module):
|
130 |
+
"""Implements SwiGLU and other gated feed-forward layers from Noam Shazeer's paper: https://arxiv.org/abs/2002.05202
|
131 |
+
"""
|
132 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, bias=True):
|
133 |
+
super().__init__()
|
134 |
+
out_features = out_features or in_features
|
135 |
+
# If gated, multiply hidden_dim by 2/3 to account for extra matmul
|
136 |
+
hidden_features = int(2 * (hidden_features or in_features) / 3)
|
137 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
138 |
+
self.act = act_layer()
|
139 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
140 |
+
self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
|
141 |
+
|
142 |
+
def forward(self, x):
|
143 |
+
x = self.fc2(self.act(self.fc1(x)) * self.fc3(x))
|
144 |
+
return x
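The 2/3 scaling of the hidden width keeps the gated block's parameter count on par with the plain Mlp, since it holds three weight matrices instead of two. With in_features=768 and mlp_ratio=4 (hidden 3072, scaled to 2048), both come to 4,718,592 weights; a quick check:

mlp = Mlp(in_features=768, hidden_features=3072, bias=False)
gated = GatedMlp(in_features=768, hidden_features=3072, bias=False)
print(sum(p.numel() for p in mlp.parameters()))    # 4718592 = 2 * 768 * 3072
print(sum(p.numel() for p in gated.parameters()))  # 4718592 = 3 * 768 * 2048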
|
145 |
+
|
146 |
+
|
147 |
+
class Attention(nn.Module):
|
148 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, proj_bias=True, attn_drop=0., proj_drop=0., allow_zero_attn=False):
|
149 |
+
super().__init__()
|
150 |
+
self.num_heads = num_heads
|
151 |
+
head_dim = dim // num_heads
|
152 |
+
self.scale = head_dim ** -0.5
|
153 |
+
self.allow_zero_attn = allow_zero_attn
|
154 |
+
|
155 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
156 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
157 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
158 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
159 |
+
|
160 |
+
def forward(self, x, mask=None):
|
161 |
+
B, N, C = x.shape
|
162 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
163 |
+
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
164 |
+
|
165 |
+
        attn = (q @ k.transpose(-2, -1)) * self.scale

        if mask is not None:
            mask = mask.unsqueeze(1)  # Unsqueeze attention mask for multi-head
            attn = attn.masked_fill(mask, -torch.finfo(attn.dtype).max)

        if self.allow_zero_attn:
            attn = softmax1(attn)
        else:
            attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class CrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, proj_bias=True, attn_drop=0., proj_drop=0., allow_zero_attn=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.allow_zero_attn = allow_zero_attn

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, context, mask=None):
        B, N, C = x.shape
        _, M, _ = context.shape

        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        kv = self.kv(context).reshape(B, M, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            mask = rearrange(mask, "b n m -> b 1 n m")  # Unsqueeze / reshape for multi-head
            attn = attn.masked_fill(mask, -torch.finfo(attn.dtype).max)

        if self.allow_zero_attn:
            attn = softmax1(attn)
        else:
            attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class NormAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, proj_bias=True, norm_layer=nn.LayerNorm, attn_drop=0., proj_drop=0., allow_zero_attn=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.allow_zero_attn = allow_zero_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.q_norm = norm_layer(head_dim)
        self.k_norm = norm_layer(head_dim)

    def forward(self, x, mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        q = self.q_norm(q)
        k = self.k_norm(k)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if mask is not None:
            mask = mask.unsqueeze(1)  # Unsqueeze for multi-head
            attn = attn.masked_fill(mask, -torch.finfo(attn.dtype).max)

        if self.allow_zero_attn:
            attn = softmax1(attn)
        else:
            attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class NormCrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, proj_bias=True, norm_layer=nn.LayerNorm, attn_drop=0., proj_drop=0., allow_zero_attn=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.allow_zero_attn = allow_zero_attn

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.q_norm = norm_layer(head_dim)
        self.k_norm = norm_layer(head_dim)

    def forward(self, x, context, mask=None):
        B, N, C = x.shape
        _, M, _ = context.shape

        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        kv = self.kv(context).reshape(B, M, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        q = self.q_norm(q)
        k = self.k_norm(k)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            mask = rearrange(mask, "b n m -> b 1 n m")  # Unsqueeze / reshape for multi-head
            attn = attn.masked_fill(mask, -torch.finfo(attn.dtype).max)

        if self.allow_zero_attn:
            attn = softmax1(attn)
        else:
            attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, proj_bias=True, mlp_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, gated_mlp=False, qk_norm=False, allow_zero_attn=False):
        super().__init__()
        self.norm1 = norm_layer(dim)

        if not qk_norm:
            self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, allow_zero_attn=allow_zero_attn)
        else:
            self.attn = NormAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, norm_layer=norm_layer, attn_drop=attn_drop, proj_drop=drop, allow_zero_attn=allow_zero_attn)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        if not gated_mlp:
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, bias=mlp_bias, drop=drop)
        else:
            self.mlp = GatedMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, bias=mlp_bias)

    def forward(self, x, mask=None):
        x = x + self.drop_path(self.attn(self.norm1(x), mask))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class DecoderBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, proj_bias=True, mlp_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, gated_mlp=False, qk_norm=False, allow_zero_attn=False):
        super().__init__()
        self.norm1 = norm_layer(dim)

        if not qk_norm:
            self.self_attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, allow_zero_attn=allow_zero_attn)
            self.cross_attn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, allow_zero_attn=allow_zero_attn)
        else:
            self.self_attn = NormAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, norm_layer=norm_layer, attn_drop=attn_drop, proj_drop=drop, allow_zero_attn=allow_zero_attn)
            self.cross_attn = NormCrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, norm_layer=norm_layer, attn_drop=attn_drop, proj_drop=drop, allow_zero_attn=allow_zero_attn)

        self.query_norm = norm_layer(dim)
        self.context_norm = norm_layer(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        if not gated_mlp:
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, bias=mlp_bias, drop=drop)
        else:
            self.mlp = GatedMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, bias=mlp_bias)

    def forward(self, x, context, sa_mask=None, xa_mask=None):
        x = x + self.drop_path(self.self_attn(self.norm1(x), sa_mask))
        x = x + self.drop_path(self.cross_attn(self.query_norm(x), self.context_norm(context), xa_mask))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class CrossAttentionBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, gated_mlp=False, allow_zero_attn=False):
        super().__init__()
        self.cross_attn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, allow_zero_attn=allow_zero_attn)
        self.query_norm = norm_layer(dim)
        self.context_norm = norm_layer(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        if not gated_mlp:
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        else:
            self.mlp = GatedMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)

    def forward(self, x, context, xa_mask=None, **kwargs):
        x = x + self.drop_path(self.cross_attn(self.query_norm(x), self.context_norm(context), xa_mask))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
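The block classes above take boolean attention masks in which True marks positions that are masked out (they are filled with a large negative value before the softmax, see the masked_fill calls). A minimal usage sketch, not part of the committed files, assuming the fourm.models.fm_utils import path from this commit and arbitrary toy shapes:

import torch
from fourm.models.fm_utils import Block, DecoderBlock

dim, num_heads = 768, 12
x = torch.randn(2, 16, dim)    # (B, N, D) query tokens
ctx = torch.randn(2, 32, dim)  # (B, M, D) context tokens for cross-attention

# Encoder-style block: self-attention only. An all-False mask masks nothing out.
enc_block = Block(dim=dim, num_heads=num_heads)
sa_mask = torch.zeros(2, 16, 16, dtype=torch.bool)
y = enc_block(x, sa_mask)      # -> (2, 16, 768)

# Decoder-style block: self-attention over x, then cross-attention into ctx.
dec_block = DecoderBlock(dim=dim, num_heads=num_heads)
xa_mask = torch.zeros(2, 16, 32, dtype=torch.bool)  # queries may attend to all context tokens
z = dec_block(x, ctx, sa_mask=sa_mask, xa_mask=xa_mask)  # -> (2, 16, 768)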
fourm/models/fm_vit.py
ADDED
@@ -0,0 +1,485 @@
# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import copy
from functools import partial
from typing import Optional, Union

import torch
from torch import nn

from fourm.utils.timm.registry import register_model
from huggingface_hub import PyTorchModelHubMixin

from .encoder_embeddings import ImageEncoderEmbedding
from .fm_utils import Block, LayerNorm
from fourm.data.modality_info import MODALITY_INFO


__all__ = [
    # GELU models
    'fm_vit_tiny_6e_gelu',
    'fm_vit_small_8e_gelu',
    'fm_vit_base_12e_gelu',
    'fm_vit_large_24e_gelu',
    'fm_vit_xlarge_24e_gelu',
    # SwiGLU models
    'fm_vit_tiny_6e_swiglu_nobias',
    'fm_vit_small_8e_swiglu_nobias',
    'fm_vit_base_12e_swiglu_nobias',
    'fm_vit_large_24e_swiglu_nobias',
    'fm_vit_xlarge_24e_swiglu_nobias',
    # SwiGLU + QKNorm models
    'fm_vit_base_12e_swiglu_qknorm_nobias',
    'fm_vit_large_24e_swiglu_qknorm_nobias',
    'fm_vit_xlarge_24e_swiglu_qknorm_nobias',
]

class FourMViT(nn.Module):
    """Modified 4M model, adapted to behave as a simple RGB-only ViT.

    Args:
        img_size (int): Input image size.
        patch_size (int): Patch size.
        in_chans (int): Number of input image channels.
        dim (int): Patch embedding dimension.
        encoder_depth (int): Depth of ViT / number of encoder blocks.
        num_heads (int): Number of attention heads in each ViT block.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool): If True, add a learnable bias to query, key, value.
        proj_bias (bool): If True, adds a bias to the attention out proj layer.
        mlp_bias (bool): If True, adds a learnable bias for the feedforward.
        drop_path_rate (float): Stochastic depth rate.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate.
        act_layer (nn.Module): Activation layer.
        norm_layer (nn.Module): Normalization layer.
        gated_mlp (bool): If True, makes the feedforward gated (e.g., for SwiGLU)
        qk_norm (bool): If True, normalizes the query and keys (as in ViT-22B)
        use_act_checkpoint (bool): If True, use activation checkpointing.
        encoder_norm (bool): If True, adds a norm layer after the last encoder block.
        output_head (Optional[nn.Module]): Optional output head after the encoder
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        dim=768,
        encoder_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        mlp_bias: bool = True,
        drop_path_rate: float = 0.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: Union[partial, nn.Module] = partial(LayerNorm, eps=1e-6),
        gated_mlp: bool = False,  # Make the feedforward gated for e.g. SwiGLU
        qk_norm: bool = False,
        encoder_norm=True,
        output_head: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.img_size = img_size
        self.init_std = 0.02
        rgb_embedding = ImageEncoderEmbedding(num_channels=in_chans, patch_size=patch_size,
                                              dim_tokens=dim, sincos_pos_emb=True, image_size=img_size)
        self.num_patches = rgb_embedding.num_patches
        self.encoder_embeddings = nn.ModuleDict({f"rgb@{img_size}": rgb_embedding})

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_depth)]

        self.encoder = nn.ModuleList([
            Block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias,
                  drop_path=dpr[i], drop=drop_rate, attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer,
                  gated_mlp=gated_mlp, qk_norm=qk_norm)
            for i in range(encoder_depth)
        ])

        self.encoder_norm = norm_layer(dim) if encoder_norm else nn.Identity()

        # Weight init
        self.init_weights()

        # Classification head is initialized after init_weights() to allow for special init scale
        if output_head is not None:
            self.output_head = output_head
            if hasattr(self.output_head, 'init'):
                self.output_head.init(dim)
        else:
            self.output_head = nn.Identity()

    def init_weights(self):
        """Weight initialization following MAE's initialization scheme"""

        for name, m in self.named_modules():
            # Skipping tokenizers to avoid reinitializing them
            if "tokenizer" in name:
                continue
            # Linear
            elif isinstance(m, nn.Linear):
                if 'qkv' in name:
                    # treat the weights of Q, K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                elif 'kv' in name:
                    # treat the weights of K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                else:
                    nn.init.xavier_uniform_(m.weight)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            # LayerNorm
            elif isinstance(m, nn.LayerNorm) or isinstance(m, LayerNorm):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0)

            # Embedding
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=self.init_std)
            # Conv2d
            elif isinstance(m, nn.Conv2d):
                if '.proj' in name:
                    # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
                    w = m.weight.data
                    nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def get_num_layers_encoder(self):
        return len(self.encoder)

    def get_num_layers(self):
        return self.get_num_layers_encoder()

    @torch.jit.ignore
    def no_weight_decay(self):
        no_wd_set = set()

        for mod, emb_module in self.encoder_embeddings.items():
            if hasattr(emb_module, 'no_weight_decay'):
                to_skip = emb_module.no_weight_decay()
                to_skip = set([f'encoder_embeddings.{mod}.{name}' for name in to_skip])
                no_wd_set = no_wd_set | to_skip

        return no_wd_set


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor. Shape (B, C, H, W)

        Returns:
            torch.Tensor: Output tensor. Shape (B, num_classes).
        """
        rgb_dict = {'tensor': x}
        rgb_dict = self.encoder_embeddings[f'rgb@{self.img_size}'](rgb_dict)

        # Add embeddings to patchified RGB image
        x = rgb_dict['x'] + rgb_dict['emb']  # Shape: (B, N, D) with N = num_patches

        for blk in self.encoder:
            x = blk(x)

        x = self.encoder_norm(x)  # Shape: (B, N, D)

        out = self.output_head(x)

        return out


    def freeze_encoder(self, freeze_embeddings=True):
        for param in self.encoder.parameters():
            param.requires_grad = False

        for param in self.encoder_norm.parameters():
            param.requires_grad = False

        if freeze_embeddings:
            for param in self.encoder_embeddings.parameters():
                param.requires_grad = False

    def unfreeze_encoder(self, unfreeze_embeddings=True):
        for param in self.encoder.parameters():
            param.requires_grad = True

        for param in self.encoder_norm.parameters():
            param.requires_grad = True

        if unfreeze_embeddings:
            for param in self.encoder_embeddings.parameters():
                param.requires_grad = True


################################################

# Wrapper for easy loading with Huggingface Hub

class FMViT(FourMViT, PyTorchModelHubMixin):
    """Wrapper around FourMViT for easy loading with Huggingface Hub.

    Args:
        config (dict): Dictionary containing the model and modality configuration,
            used for loading from Huggingface Hub.
        output_head (nn.Module): Optional output head.
    """
    def __init__(self, config: dict, output_head: Optional[nn.Module] = None):

        config = copy.deepcopy(config)

        config['norm_layer'] = partial(LayerNorm, eps=1e-6, bias=config['norm_bias'])
        config['act_layer'] = getattr(torch.nn, config['act_layer'])

        img_size = config['image_size']
        config['img_size'] = img_size
        config['patch_size'] = MODALITY_INFO[f'rgb@{img_size}'].get('patch_size', config['patch_size'])
        config['in_chans'] = MODALITY_INFO[f'rgb@{img_size}'].get('num_channels', 3)

        for key in ['image_size', 'norm_bias', 'domains_in', 'domains_out', 'decoder_depth', 'share_modality_embeddings']:
            if key in config:
                del config[key]

        super().__init__(
            output_head=output_head,
            **config
        )


################################################

# Model definitions

# GELU variants
@register_model
def fm_vit_tiny_6e_gelu(**kwargs):
    model = FourMViT(
        encoder_depth=6,
        dim=384,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


@register_model
def fm_vit_small_8e_gelu(**kwargs):
    model = FourMViT(
        encoder_depth=8,
        dim=512,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


@register_model
def fm_vit_base_12e_gelu(**kwargs):
    model = FourMViT(
        encoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


@register_model
def fm_vit_large_24e_gelu(**kwargs):
    model = FourMViT(
        encoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model

@register_model
def fm_vit_xlarge_24e_gelu(**kwargs):
    model = FourMViT(
        encoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


# SwiGLU variants
@register_model
def fm_vit_tiny_6e_swiglu_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=6,
        dim=384,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        **kwargs
    )
    return model


@register_model
def fm_vit_small_8e_swiglu_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=8,
        dim=512,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        **kwargs
    )
    return model


@register_model
def fm_vit_base_12e_swiglu_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        **kwargs
    )
    return model


@register_model
def fm_vit_large_24e_swiglu_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        **kwargs
    )
    return model

@register_model
def fm_vit_xlarge_24e_swiglu_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        **kwargs
    )
    return model

# SwiGLU + QKNorm variants

@register_model
def fm_vit_base_12e_swiglu_qknorm_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=12,
        dim=768,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
        **kwargs
    )
    return model


@register_model
def fm_vit_large_24e_swiglu_qknorm_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=24,
        dim=1024,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
        **kwargs
    )
    return model

@register_model
def fm_vit_xlarge_24e_swiglu_qknorm_nobias(**kwargs):
    model = FourMViT(
        encoder_depth=24,
        dim=2048,
        num_heads=32,
        mlp_ratio=4,
        qkv_bias=False,
        proj_bias=False,
        mlp_bias=False,
        norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
        act_layer=nn.SiLU,
        gated_mlp=True,
        qk_norm=True,
        **kwargs
    )
    return model
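A second illustrative sketch, not part of the committed files: instantiating one of the registered RGB-only ViT variants defined above. The linear head, class count, and input resolution are placeholder assumptions; note that output_head is applied to the full token sequence, so a pooling head would be needed for per-image logits.

import torch
from torch import nn
from fourm.models.fm_vit import fm_vit_base_12e_gelu

head = nn.Linear(768, 1000)  # hypothetical per-token classification head
model = fm_vit_base_12e_gelu(output_head=head)

images = torch.randn(2, 3, 224, 224)  # (B, C, H, W) at the default 224px resolution
token_logits = model(images)          # (2, 196, 1000): one prediction per 16x16 patch token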
fourm/models/generate.py
ADDED
@@ -0,0 +1,1273 @@
# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import Union, List, Optional

import numpy as np
import torch
from einops import rearrange, repeat
from torch import nn
import torch.nn.functional as F

from fourm.utils import get_sentinel_to_id_mapping, merge_span_masking
from fourm.utils.generation import cosine_schedule, linear_schedule, onex_temp_schedule, linear_temp_schedule, continue_schedule
from tqdm import tqdm
import copy



def empty_img_modality(mod_dict, key):
    # Input mask
    mod_dict[key]['input_mask'][:] = True

    # Target Mask
    mod_dict[key]['target_mask'][:] = False

    return mod_dict

def empty_seq_modality(mod_dict, key, s1_id=5):
    # To create an empty sequence, we suppose an input budget of 1, and the rest assigned to targets

    # Input tensor
    # Input is [S_1], target is [S_1] ...... [S_2]
    # (so [S_1] [S_1] ..... [S_2] when combined)
    mod_dict[key]['tensor'][:] = 0
    mod_dict[key]['tensor'][:,[0,1]] = s1_id  # s1_id is id of the first sentinel token ([S_1])
    mod_dict[key]['tensor'][:,-1] = s1_id + 1

    # Input mask
    # Set first token to input (i.e. 0), rest to target (i.e. 1)
    mod_dict[key]['input_mask'][:] = True
    mod_dict[key]['input_mask'][:,0] = False

    # Target Mask
    mod_dict[key]['target_mask'] = ~mod_dict[key]['input_mask']

    # Decoder attn mask
    # WARNING: Not needed / used in GenerationSampler, where causal mask is enforced
    # First token is input, not part of target
    mod_dict[key]['decoder_attention_mask'][:] = 1
    mod_dict[key]['decoder_attention_mask'][:, 0] = 0

    return mod_dict

def empty_seq_emb_modality(mod_dict, key):
    # Tensor
    mod_dict[key]['tensor'] = torch.zeros_like(mod_dict[key]['tensor'])

    # Input mask
    mod_dict[key]['input_mask'] = torch.ones_like(mod_dict[key]['input_mask'])
    # It is crucial to specify the input mask as such, CFG won't work otherwise!
    mod_dict[key]['input_mask'][:, 0] = False

    # Target Mask
    mod_dict[key]['target_mask'] = torch.ones_like(mod_dict[key]['target_mask'])

    # Decoder attn mask
    mod_dict[key]['decoder_attention_mask'][:] = False

    return mod_dict


def init_empty_target_modality(mod_dict, modality_info, domain, batch_size, num_tokens, device):
    """
    Initializes an empty target modality dictionary for a given domain.
    Used to initialize target modality dictionaries for generation.
    """
    if modality_info[domain]['type'] == 'img':
        # Initialize mod dict
        mod_dict[domain] = {
            'tensor': torch.zeros((batch_size, num_tokens), dtype=torch.int64, device=device),
            'input_mask': torch.ones((batch_size, num_tokens), dtype=torch.bool, device=device),
            'target_mask': torch.zeros((batch_size, num_tokens), dtype=torch.bool, device=device),
        }
        # Set it to the correct values
        mod_dict = empty_img_modality(mod_dict, domain)

    elif modality_info[domain]['type'] in ['seq', 'seq_token', 'seq_emb']:
        # Initialize mod dict
        num_tokens = max(num_tokens, 2)
        mod_dict[domain] = {
            'tensor': torch.zeros((batch_size, num_tokens), dtype=torch.int32, device=device),
            'input_mask': torch.ones((batch_size, num_tokens), dtype=torch.bool, device=device),
            'target_mask': torch.zeros((batch_size, num_tokens), dtype=torch.bool, device=device),
            'decoder_attention_mask': torch.zeros((batch_size, num_tokens), dtype=torch.bool, device=device),
        }
        # Set it to the correct values
        if modality_info[domain]['type'] in ['seq', 'seq_token']:
            mod_dict = empty_seq_modality(mod_dict, domain)
        elif modality_info[domain]['type'] == 'seq_emb':
            mod_dict = empty_seq_emb_modality(mod_dict, domain)
    else:
        raise ValueError()

    return mod_dict

def init_full_input_modality(mod_dict, modality_info, domain, device, eos_id=3):
    if domain.startswith('rgb'):
        batch_size, _, H, W = mod_dict[domain]['tensor'].shape
        patch_size = modality_info[domain]['patch_size']
        num_tokens = (H // patch_size) * (W // patch_size)
        shape = (batch_size, num_tokens)
    else:
        shape = mod_dict[domain]['tensor'].shape
    if 'input_mask' not in mod_dict[domain]:
        mod_dict[domain]['input_mask'] = torch.zeros(shape, dtype=torch.bool, device=device)
    if 'target_mask' not in mod_dict[domain]:
        mod_dict[domain]['target_mask'] = torch.ones(shape, dtype=torch.bool, device=device)
    if 'decoder_attention_mask' not in mod_dict[domain]:
        mod_dict[domain]['decoder_attention_mask'] = torch.zeros(shape, dtype=torch.bool, device=device)

    if modality_info[domain]['type'] == 'img':
        mod_dict[domain]['input_mask'][:] = False
        mod_dict[domain]['target_mask'][:] = True

    elif modality_info[domain]['type'] in ['seq', 'seq_token']:
        if eos_id in mod_dict[domain]['tensor']:
            eos_idx = torch.where(mod_dict[domain]['tensor'] == eos_id)[1][0].item()
        else:
            mod_dict[domain]['tensor'][:,0] = eos_id
            eos_idx = 0
        mod_dict[domain]['input_mask'][:,:eos_idx+1] = False
        mod_dict[domain]['input_mask'][:,eos_idx+1:] = True
        mod_dict[domain]['target_mask'][:] = True

    elif modality_info[domain]['type'] in ['seq_emb']:
        # T5 caption has the valid mask saved alongside the embeddings
        mod_dict[domain]['input_mask'] = ~mod_dict[domain]['mask_valid']
        mod_dict[domain]['target_mask'] = torch.ones_like(mod_dict[domain]['mask_valid'])
        mod_dict[domain]['decoder_attention_mask'] = torch.zeros_like(mod_dict[domain]['mask_valid'])

    return mod_dict

def custom_text(sample, input_text, eos_token, key, device, text_tokenizer, target_max_len=50, start_token="[S_1]"):
    input_ids = text_tokenizer.encode(input_text).ids
    input_ids = torch.tensor(input_ids).unsqueeze(0)

    target_text = [start_token]
    target_text.extend(["[PAD]"] * (target_max_len - 2))
    target_text.append(eos_token)
    target_text = " ".join(target_text)
    target_ids = text_tokenizer.encode(target_text).ids
    target_ids = torch.tensor(target_ids).unsqueeze(0)

    all_ids = torch.cat([input_ids, target_ids], dim=1)

    input_mask = torch.cat([
        torch.zeros_like(input_ids, dtype=torch.bool),
        torch.ones_like(target_ids, dtype=torch.bool),
    ], dim=1)

    target_mask = torch.cat([
        torch.ones_like(input_ids, dtype=torch.bool),
        torch.zeros_like(target_ids, dtype=torch.bool),
    ], dim=1)

    sample[key] = {}
    sample[key]['tensor'] = all_ids.to(device)
    sample[key]['input_mask'] = input_mask.to(device)
    sample[key]['target_mask'] = target_mask.to(device)
    sample[key]['decoder_attention_mask'] = torch.zeros(all_ids.shape, dtype=torch.bool, device=device)

    return sample

def expand_to_batch(mod_dict, batch_size):
    for mod, d in mod_dict.items():
        for k, v in d.items():
            if k in ['tensor', 'input_mask', 'target_mask', 'decoder_attention_mask', 'mask_valid']:
                B = v.shape[0]
                if B == 1:
                    mod_dict[mod][k] = repeat(v, "1 ... -> b ...", b=batch_size)
                elif B != batch_size:
                    raise ValueError(f"Invalid batch size: {B} instead of {batch_size}")

    return mod_dict

def build_chained_generation_schedules(
        cond_domains: List[str],
        target_domains: List[str],
        tokens_per_target: List[int],
        autoregression_schemes: List[str],
        decoding_steps: List[int],
        token_decoding_schedules: List[str],
        temps: List[float],
        temp_schedules: List[float],
        cfg_scales: List[float],
        cfg_schedules: List[str],
        cfg_grow_conditioning: bool = False,
        modality_info: Optional[dict] = None,
    ):
    """
    Builds a list of chained generation schedules, where each schedule is a tuple of the form:
    (target_modality, schema, number of decoded tokens, temperature, guidance_scale, cfg_cond_domains)

    Args:
        cond_domains: List of conditioning domains
        target_domains: List of target domains
        tokens_per_target: List of number of tokens to decode for each target domain
        autoregression_schemes: List of autoregression schemes for each target domain. maskgit, roar, or autoregressive
        decoding_steps: List of number of maskgit steps for each target domain (if applicable)
        token_decoding_schedules: List of maskgit token schedules for each target domain (if applicable). cosine or linear
        temps: List of starting temperatures for each target domain
        temp_schedules: List of temperature schedules for each target domain. linear, constant, or onex:{min_t}:{power}
        cfg_scales: List of classifier-free guidance scales for each target domain
        cfg_schedules: List of classifier-free guidance schedules for each target domain. constant or cosine
        cfg_grow_conditioning: After every completed modality, add them to classifier-free guidance conditioning
        modality_info: Dictionary with metadata for each modality, optionally used to verify that the schedule is compatible with the modality
    """

    # List of {target_modality, schema, number of decoded tokens, temperature, guidance_scale, cfg_cond_domains} dicts
    chained_schedules = []

    cond_domains = cond_domains.copy()

    for target_idx in range(len(target_domains)):

        scheme = autoregression_schemes[target_idx]
        target_domain = target_domains[target_idx]
        ntoks = tokens_per_target[target_idx]
        maskgit_token_schedule_name = token_decoding_schedules[target_idx]
        temp = temps[target_idx]
        temp_schedule_name = temp_schedules[target_idx]
        cfg_scale = cfg_scales[target_idx]
        cfg_schedule_name = cfg_schedules[target_idx]

        # Auto-regressive (caption, detection, ...)
        if scheme == 'autoregressive':
            chained_schedules.append({
                'target_domain': target_domain,
                'scheme': scheme,
                'num_tokens': None,
                'temperature': temp,
                'cfg_scale': cfg_scale,
                'cfg_cond_domains': cond_domains.copy()
            })
            continue

        # Use modality info for (optional) assert if provided
        if modality_info is not None:
            assert modality_info[target_domain]['type'] not in ['seq', 'seq_token'], f'Illegal autoregressive scheme {scheme} for target domain {target_domain}'

        # Token schedule
        if scheme == 'maskgit':
            # MaskGIT token schedule setup
            num_steps = decoding_steps[target_idx]
            if maskgit_token_schedule_name == 'cosine':
                token_schedule = cosine_schedule(num_steps, (ntoks))
            elif maskgit_token_schedule_name == 'linear':
                token_schedule = linear_schedule(num_steps, (ntoks))
            else:
                raise ValueError(f'Illegal MaskGIT token schedule {maskgit_token_schedule_name}')
        elif scheme == 'roar':
            # ROAR token schedule setup (one-by-one, but random order)
            num_steps = decoding_steps[target_idx]
            token_schedule = linear_schedule(num_steps, ntoks)
        else:
            raise ValueError(f'Illegal decoding scheme {scheme}')

        # Temperature schedule
        if temp_schedule_name == 'linear':
            temp_schedule = linear_temp_schedule(temp, token_schedule)
        elif temp_schedule_name == 'constant':
            temp_schedule = temp * np.ones(num_steps)
        elif 'onex' in temp_schedule_name:
            # onex temperature schedule has to be formatted like onex:{min_t}:{power}
            min_t, power = [float(f) for f in temp_schedule_name.split(':')[1:]]
            temp_schedule = onex_temp_schedule(max_t=temp, min_t=min_t, token_schedule=token_schedule, power=power)
        else:
            raise ValueError(f'Illegal temperature schedule {temp_schedule_name}')

        # Classifier-free guidance scale schedule
        if cfg_schedule_name == 'constant':
            if isinstance(cfg_scale, float):
                cfg_schedule = cfg_scale * np.ones(num_steps)
            elif isinstance(cfg_scale, list):
                cfg_schedule = np.array(cfg_scale) * np.ones(num_steps).reshape(-1, 1)
        elif cfg_schedule_name == 'cosine':
            raise NotImplementedError()
        else:
            raise ValueError(f'Illegal guidance schedule {cfg_schedule_name}')

        # Concatenate schedule for this modality with previous ones
        schedule = [
            {
                'target_domain': target_domain,
                'scheme': scheme,
                'num_tokens': tok,
                'temperature': temp,
                'cfg_scale': cfg,
                'cfg_cond_domains': cond_domains.copy()
            }
            for tok, temp, cfg in zip(token_schedule, temp_schedule, cfg_schedule)
        ]
        chained_schedules.extend(schedule)

        # Optionally add this new modality to the ones affected by classifier-free guidance
        if cfg_grow_conditioning:
            cond_domains.append(target_domain)

    return chained_schedules


class GenerationSampler(nn.Module):
    """Sampler that wraps a trained 4M model for generation use cases.
    Implements standard autoregressive, MaskGIT, and ROAR generation schemes with chaining and weighted guidance."""

    def __init__(self, model):
        super().__init__()
        self.model = model


    def top_k_top_p_filtering(self, logits, top_k=0.0, top_p=0.0):
        # Compatible with batching
        # From https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
        if top_k > 0.0:
            if isinstance(top_k, int):
                k = min(top_k, logits.shape[-1])
            elif isinstance(top_k, float):
                k = min(int(top_k * logits.shape[-1]), logits.shape[-1])
            else:
                raise ValueError(f"Invalid value for top_k: {top_k}")

            # Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = logits < torch.topk(logits, k)[0][..., -1, None]
            logits[indices_to_remove] = float("-inf")

        if top_p > 0.0:
            sorted_logits, sorted_indices = torch.sort(logits, dim=1, descending=True)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cum_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            restore_indices = torch.argsort(sorted_indices, dim=-1)
            indices_to_remove = torch.gather(sorted_indices_to_remove, dim=-1, index=restore_indices)
            logits[indices_to_remove] = float("-inf")

        return logits

    def sample_tokens(self, logits, temperature=1.0, top_k=0.0, top_p=0.0):
        if np.isclose(temperature, 0, atol=1e-10):
            samples = torch.argmax(logits, dim=-1)
            # Since argmax is used, all sampled_probs will be 1 as we're selecting the max probability
            sampled_probs = torch.ones_like(samples, dtype=torch.float32)
        else:
            filtered_logits = self.top_k_top_p_filtering(logits, top_k, top_p)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            samples = torch.multinomial(probs, 1)[:, 0]
            sampled_probs = probs[torch.arange(len(samples)), samples]
        return samples, sampled_probs

    def sample_tokens_batched(self, logits, temperature=1.0, top_k=0.0, top_p=0.0):
        if logits.ndim > 2:
            B, N = logits.shape[0], logits.shape[1]
            logits = rearrange(logits, 'b n v -> (b n) v')
            samples, sampled_probs = self.sample_tokens(logits, temperature, top_k, top_p)
            samples = rearrange(samples, '(b n) -> b n', b=B, n=N)
            sampled_probs = rearrange(sampled_probs, '(b n) -> b n', b=B, n=N)
            return samples, sampled_probs
        else:
            return self.sample_tokens(logits, temperature, top_k, top_p)

    def select_tokens(self, logits, num_select, temperature=1.0, top_k=0.0, top_p=0.0, return_all_samples=False):
        samples, sampled_probs = self.sample_tokens(logits, temperature, top_k, top_p)
        top_indices = torch.topk(sampled_probs, num_select)[1]
        top_samples = samples[top_indices]
        if return_all_samples:
            return top_samples, top_indices, samples
        else:
            return top_samples, top_indices

    def select_tokens_batched(self, logits, num_select, temperature=1.0, top_k=0.0, top_p=0.0, return_all_samples=False):
        if logits.ndim > 2:
            samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k, top_p)  # both of shape (B, N)
            top_indices = torch.topk(sampled_probs, num_select, dim=-1)[1]
            # Need to switch to gather instead of indexing here
            top_samples = torch.gather(samples, dim=-1, index=top_indices)
            if return_all_samples:
                return top_samples, top_indices, samples
            else:
                return top_samples, top_indices
        else:
            return self.sample_tokens(logits, num_select, temperature, top_k, top_p, return_all_samples)


    def forward_mask_encoder_generation(self, encoder_mod_dict):
        """Modification of forward_mask_encoder adapted for generation, with support for batching
        """
        # Form input
        B = list(encoder_mod_dict.values())[0]['tensor'].shape[0]

        encoder_tokens_all, emb_all, encoder_mask_all, mod_mask_all = self.model.cat_encoder_tensors(encoder_mod_dict)
        # Take max num encoder of tokens (although assuming it's the same everywhere would be better)
        num_encoder_tokens = (~encoder_mask_all.reshape(B, -1)).sum(dim=1).max()

        # Add arange multiplied by small constant to mask so they get sorted in a deterministic way
        mask_arange = torch.arange(encoder_mask_all.shape[1], device=encoder_mask_all.device).unsqueeze(0) * 1e-6
        ids_shuffle = torch.argsort(encoder_mask_all + mask_arange, dim=1)
        # ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :num_encoder_tokens]

        encoder_tokens = torch.gather(encoder_tokens_all, dim=1,
                                      index=repeat(ids_keep, "b n -> b n d", d=encoder_tokens_all.shape[2]))
        encoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
        encoder_mask = torch.gather(encoder_mask_all, dim=1, index=ids_keep)
        mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)

        if self.model.num_register_tokens > 0:
            prompt_tokens = repeat(self.prompt_tokens, '() n d -> b n d', b=B)
            # We add prompt tokens at the beginning of the sequence
            encoder_tokens = torch.cat([prompt_tokens, encoder_tokens], dim=1)
            encoder_emb = torch.cat([torch.zeros_like(prompt_tokens), encoder_emb], dim=1)
            encoder_mask = torch.cat([torch.zeros((B, prompt_tokens.shape[1]), dtype=torch.bool, device=encoder_mask.device), encoder_mask], dim=1)
            mod_mask = torch.cat([torch.full((B, prompt_tokens.shape[1]), -1, dtype=torch.int16, device=mod_mask.device), mod_mask], dim=1)

        encoder_tokens[encoder_mask] = 0.
        encoder_emb[encoder_mask] = 0.
        mod_mask[encoder_mask] = -1
        # Mask could be of shape 'b n1 n2' but not needed for masked_fill
        # This means this mask can then be re-used for decoder cross-attention
        encoder_mask = rearrange(encoder_mask, 'b n2 -> b 1 n2')

        return encoder_tokens, encoder_emb, encoder_mask, mod_mask


    def forward_mask_decoder_maskgit(self, mod_dict, target_mod, seed=None):
        """Modification of forward_mask_decoder for MaskGIT generation, with support for batching
        """
        if seed is not None:
            torch.manual_seed(seed)
        d = mod_dict[target_mod]
        decoder_tokens_all = torch.zeros_like(d['x']) + self.model.mask_token
        emb_all = d['emb']
        decoder_mask_all = d['target_mask']
        B = decoder_tokens_all.shape[0]  # Get batch size
        mod_mask_all = torch.full_like(d['ids'], self.model.modality_info[target_mod]['id'], dtype=torch.int16)
        mod_pos_all = torch.arange(d['x'].shape[1], device=d['x'].device).unsqueeze(0)
        mod_pos_all = repeat(mod_pos_all, '1 n -> b n', b=B)  # Added: Expansion for batching
        num_decoder_tokens = (~decoder_mask_all[0]).sum()  # Adapted for batching / Assumes num_decoder_tokens is the same across the batch


        # Add arange multiplied by small constant to mask so they get sorted in a deterministic way
        mask_arange = torch.arange(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6
        ids_shuffle = torch.argsort(decoder_mask_all + mask_arange, dim=1)
        # ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :num_decoder_tokens]

        decoder_tokens = torch.gather(decoder_tokens_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=decoder_tokens_all.shape[2]))
        decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
        decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep)
        mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)
        mod_pos = torch.gather(mod_pos_all, dim=1, index=ids_keep)

        decoder_tokens[decoder_mask] = 0.
        decoder_emb[decoder_mask] = 0.
        mod_mask[decoder_mask] = -1

        return decoder_tokens, decoder_emb, decoder_mask, mod_mask, mod_pos

    def forward_mask_decoder_roar(self, mod_dict, target_mod, num_select, seed=None):
        """Modification of forward_mask_decoder for ROAR generation, with support for batching
        """
        if seed is not None:
            torch.manual_seed(seed)
        d = mod_dict[target_mod]
        decoder_tokens_all = torch.zeros_like(d['x']) + self.model.mask_token
        emb_all = d['emb']
        decoder_mask_all = d['target_mask']
        B = decoder_tokens_all.shape[0]  # Get batch size
        mod_mask_all = torch.full_like(d['ids'], self.model.modality_info[target_mod]['id'], dtype=torch.int16)
        mod_pos_all = torch.arange(d['x'].shape[1], device=d['x'].device).unsqueeze(0)
        mod_pos_all = repeat(mod_pos_all, '1 n -> b n', b=B)  # Added: Expansion for batching
        # Only keep the first num_select tokens
        num_decoder_tokens = min(num_select, (~decoder_mask_all[0]).sum())  # Adapted for batching / Assumes num_decoder_tokens is the same across the batch

        # Add a small random number to the mask so they get sorted in a random way, but keeping the masked tokens first
        mask_rand = torch.rand(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6
        ids_shuffle = torch.argsort(decoder_mask_all + mask_rand, dim=1)
        # ids_restore = torch.argsort(ids_shuffle, dim=1)
        # Only keep the first num_select_tokens
        ids_keep = ids_shuffle[:, :num_decoder_tokens]

        decoder_tokens = torch.gather(decoder_tokens_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=decoder_tokens_all.shape[2]))
        decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
        decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep)
        mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)
        mod_pos = torch.gather(mod_pos_all, dim=1, index=ids_keep)

        decoder_tokens[decoder_mask] = 0.
        decoder_emb[decoder_mask] = 0.
        mod_mask[decoder_mask] = -1

        return decoder_tokens, decoder_emb, decoder_mask, mod_mask, mod_pos

    def forward_mask_decoder_autoregressive(self, mod_dict, target_mod, seed=None):
        # Adapted for batching
        if seed is not None:
            torch.manual_seed(seed)
        # This is the concatenation part
        d = mod_dict[target_mod]
        decoder_ids_all = d['ids']
        emb_all = d['emb']
        decoder_mask_all = d['target_mask']
        B = decoder_ids_all.shape[0]  # Get batch size
        mod_mask_all = torch.full_like(d['ids'], self.model.modality_info[target_mod]['id'], dtype=torch.int16)
        mod_pos_all = torch.arange(d['x'].shape[1], device=d['x'].device).unsqueeze(0)
        mod_pos_all = repeat(mod_pos_all, '1 n -> b n', b=B)
        num_decoder_tokens = (~decoder_mask_all[0]).sum()  # Adapted for batching, but assumes num_decoder_tokens is the same across the batch

        # Add arange multiplied by small constant to mask so they get sorted in a deterministic way
        mask_arange = torch.arange(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6
        ids_shuffle = torch.argsort(decoder_mask_all + mask_arange, dim=1)
        # ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :num_decoder_tokens]

        # Same as in forward_mask_decoder
        decoder_ids = torch.gather(decoder_ids_all, dim=1, index=ids_keep)
        decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
        decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep)
        mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)
        mod_pos = torch.gather(mod_pos_all, dim=1, index=ids_keep)

        decoder_ids[decoder_mask] = 0
        decoder_emb[decoder_mask] = 0.
        mod_mask[decoder_mask] = -1

        return decoder_ids, decoder_emb, decoder_mask, mod_mask, mod_pos

    def merge_sequences(self, mod_dict, pred_ids, target_mod, text_tokenizer, default_sentinel="[S_1]"):
        device = mod_dict[target_mod]['tensor'].device
        # Get input ids
        input_ids = mod_dict[target_mod]['tensor'].squeeze().detach().cpu()
        input_ids = input_ids[mod_dict[target_mod]['input_mask'].squeeze().detach().cpu() == 0]
        input_ids = input_ids.tolist()

        if len(input_ids) == 0:
            input_ids = [text_tokenizer.get_vocab()[default_sentinel]]

        # Get predicted ids
        pred_ids = pred_ids.squeeze().detach().cpu().tolist()
        if isinstance(pred_ids, int):
            pred_ids = [pred_ids]

        # Get sentinel ids using the tokenizer
        sentinel_ids = set(get_sentinel_to_id_mapping(text_tokenizer).values())
        # Perform merging
        merged_ids = merge_span_masking(input_ids, pred_ids, sentinel_ids)
        merged_ids = torch.tensor(merged_ids).unsqueeze(0)
        # Create new dict
        new_input_mask = torch.zeros_like(merged_ids, dtype=torch.bool)
        new_target_mask = torch.ones_like(merged_ids, dtype=torch.bool)
        new_dict = {'tensor': merged_ids.to(device),
                    'input_mask': new_input_mask.to(device),
                    'target_mask': new_target_mask.to(device)}
        new_dict['decoder_attention_mask'] = torch.zeros_like(new_target_mask, dtype=torch.bool)

        mod_dict[target_mod] = new_dict
        return mod_dict

    def merge_sequences_batched(self, mod_dict, pred_ids, target_mod, text_tokenizer, default_sentinel="[S_1]"):
        # Unbatches and calls merge sequence per batch, then regroups it into a batch

        pad_id = text_tokenizer.token_to_id("[PAD]")

        B = mod_dict[target_mod]['tensor'].shape[0]
        device = mod_dict[target_mod]['tensor'].device
|
588 |
+
|
589 |
+
tensors = torch.split(mod_dict[target_mod]['tensor'], 1)
|
590 |
+
input_masks = torch.split(mod_dict[target_mod]['input_mask'], 1)
|
591 |
+
pred_ids = torch.split(pred_ids, 1)
|
592 |
+
|
593 |
+
input_dicts = []
|
594 |
+
for t, im in zip(tensors, input_masks):
|
595 |
+
d = {target_mod: {'tensor': t, 'input_mask': im}}
|
596 |
+
input_dicts.append(d)
|
597 |
+
|
598 |
+
merged_tensors = []
|
599 |
+
merged_input_masks = []
|
600 |
+
merged_target_masks = []
|
601 |
+
merged_seq_lens = []
|
602 |
+
for input_d, pi in zip(input_dicts, pred_ids):
|
603 |
+
# Output of merge_sequences is mod_dict with modified target mod
|
604 |
+
merged_d = self.merge_sequences(input_d, pi, target_mod, text_tokenizer, default_sentinel)[target_mod]
|
605 |
+
merged_tensors.append(merged_d['tensor'])
|
606 |
+
merged_input_masks.append(merged_d['input_mask'])
|
607 |
+
merged_target_masks.append(merged_d['input_mask'])
|
608 |
+
merged_seq_lens.append(merged_d['tensor'].shape[1])
|
609 |
+
|
610 |
+
|
611 |
+
max_seq_len = max(merged_seq_lens)
|
612 |
+
|
613 |
+
for i in range(len(merged_tensors)):
|
614 |
+
# Right pad all tensors
|
615 |
+
p1d = (0, max_seq_len - merged_seq_lens[i])
|
616 |
+
merged_tensors[i] = F.pad(merged_tensors[i], p1d, "constant", pad_id)
|
617 |
+
merged_input_masks[i] = F.pad(merged_input_masks[i], p1d, "constant", True)
|
618 |
+
merged_target_masks[i] = F.pad(merged_target_masks[i], p1d, "constant", True)
|
619 |
+
|
620 |
+
new_dict = {'tensor': torch.cat(merged_tensors, dim=0).to(device),
|
621 |
+
'input_mask': torch.cat(merged_input_masks, dim=0).to(device),
|
622 |
+
'target_mask': torch.cat(merged_target_masks, dim=0).to(device)}
|
623 |
+
new_dict['decoder_attention_mask'] = torch.zeros_like(new_dict['target_mask'], dtype=torch.bool)
|
624 |
+
|
625 |
+
mod_dict[target_mod] = new_dict
|
626 |
+
return mod_dict
|
627 |
+
|
628 |
+
def forward_enc_dec_maskgit_batched(self, mod_dict, target_mod, seed=None):
|
629 |
+
# Encoder
|
630 |
+
encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d)
|
631 |
+
for mod, d in mod_dict.items()
|
632 |
+
if mod in self.model.encoder_embeddings}
|
633 |
+
encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict)
|
634 |
+
x = encoder_tokens + encoder_emb
|
635 |
+
x = self.model.forward_encoder(x, encoder_mask)
|
636 |
+
|
637 |
+
# Decoder
|
638 |
+
context = self.model.decoder_proj_context(x) + encoder_emb
|
639 |
+
decoder_mod_dict = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])}
|
640 |
+
decoder_tokens, decoder_emb, decoder_mask, decoder_mod_mask, mod_pos = self.forward_mask_decoder_maskgit(decoder_mod_dict, target_mod, seed=seed)
|
641 |
+
y = decoder_tokens + decoder_emb
|
642 |
+
y = self.model.forward_decoder(y, context, encoder_mask, None)
|
643 |
+
B, N, D = y.shape
|
644 |
+
logits = self.model.forward_logits(y, decoder_mod_dict, decoder_mod_mask)[target_mod]
|
645 |
+
logits = logits.reshape(B, N, -1)
|
646 |
+
|
647 |
+
|
648 |
+
return logits, mod_pos
|
649 |
+
|
650 |
+
def maskgit_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p, seed=None):
|
651 |
+
logits, mod_pos = self.forward_enc_dec_maskgit_batched(mod_dict, target_mod, seed=seed)
|
652 |
+
|
653 |
+
# MaskGIT sampling
|
654 |
+
top_samples, top_indices = self.select_tokens_batched(logits, num_select,
|
655 |
+
temperature=temperature, top_k=top_k, top_p=top_p)
|
656 |
+
# Update mod dict
|
657 |
+
# We rely on gather / scatter for batched operations
|
658 |
+
top_pos = torch.gather(mod_pos, -1, top_indices) # (B, num_select)
|
659 |
+
mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, top_pos, top_samples)
|
660 |
+
mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool))
|
661 |
+
mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool))
|
662 |
+
|
663 |
+
return mod_dict
|
664 |
+
|
665 |
+
def guided_maskgit_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p,
|
666 |
+
conditioning=[], guidance_scale=1.0, seed=None, write_all_predictions=False):
|
667 |
+
|
668 |
+
### 1 - First pass, with conditioning
|
669 |
+
logits_cond, _ = self.forward_enc_dec_maskgit_batched(mod_dict, target_mod, seed=seed)
|
670 |
+
|
671 |
+
### 2 - Second pass, without conditioning
|
672 |
+
mod_dict_uncond = copy.deepcopy(mod_dict)
|
673 |
+
for mod in conditioning:
|
674 |
+
if self.model.modality_info[mod]['type'] in ['seq', 'seq_token']:
|
675 |
+
mod_dict_uncond = empty_seq_modality(mod_dict_uncond, mod)
|
676 |
+
elif self.model.modality_info[mod]['type'] in ['seq_emb']:
|
677 |
+
mod_dict_uncond = empty_seq_emb_modality(mod_dict_uncond, mod)
|
678 |
+
else:
|
679 |
+
mod_dict_uncond = empty_img_modality(mod_dict_uncond, mod)
|
680 |
+
|
681 |
+
logits_uncond, mod_pos = self.forward_enc_dec_maskgit_batched(mod_dict_uncond, target_mod, seed=seed)
|
682 |
+
|
683 |
+
### 3 - Classifier-free guidance
|
684 |
+
logits = logits_uncond + (logits_cond - logits_uncond) * guidance_scale
|
685 |
+
|
686 |
+
### 4 - MaskGIT sampling
|
687 |
+
top_samples, top_indices, all_samples = self.select_tokens_batched(
|
688 |
+
logits, num_select,
|
689 |
+
temperature=temperature, top_k=top_k, top_p=top_p,
|
690 |
+
return_all_samples=True
|
691 |
+
)
|
692 |
+
|
693 |
+
### 5 - Update mod dict
|
694 |
+
# We rely on gather / scatter for batched operations
|
695 |
+
top_pos = torch.gather(mod_pos, -1, top_indices) # (B, num_select)
|
696 |
+
if write_all_predictions:
|
697 |
+
mod_dict[target_mod]['tensor'][:, mod_pos] = all_samples
|
698 |
+
else:
|
699 |
+
mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, top_pos, top_samples)
|
700 |
+
mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool))
|
701 |
+
mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool))
|
702 |
+
|
703 |
+
return mod_dict
|
704 |
+
|
705 |
+
def multi_guided_maskgit_step_batched(self, uncond_dict, cond_dicts, cond_weights, target_mod, num_select,
|
706 |
+
temperature, top_k, top_p, seed=None, write_all_predictions=False):
|
707 |
+
|
708 |
+
### 1 - Conditional forward passes (one for each guided condition)
|
709 |
+
logits_cond_all = []
|
710 |
+
for cond_dict in cond_dicts:
|
711 |
+
logits_cond_i, _ = self.forward_enc_dec_maskgit_batched(cond_dict, target_mod, seed=seed)
|
712 |
+
logits_cond_all.append(logits_cond_i)
|
713 |
+
|
714 |
+
### 2 - Unconditional forward pass
|
715 |
+
logits_uncond, mod_pos = self.forward_enc_dec_maskgit_batched(uncond_dict, target_mod, seed=seed)
|
716 |
+
|
717 |
+
### 3 - Conjunction of multiple conditions: l_uncond + sum_i{w_i * (l_cond_i - l_uncond)}
|
718 |
+
# See https://arxiv.org/abs/2206.01714
|
719 |
+
logits = logits_uncond + torch.stack([w * (logits_cond - logits_uncond) for w, logits_cond in zip(cond_weights, logits_cond_all)]).sum(dim=0)
|
720 |
+
|
721 |
+
### 4 - MaskGIT sampling
|
722 |
+
top_samples, top_indices, all_samples = self.select_tokens_batched(
|
723 |
+
logits, num_select,
|
724 |
+
temperature=temperature, top_k=top_k, top_p=top_p,
|
725 |
+
return_all_samples=True
|
726 |
+
)
|
727 |
+
|
728 |
+
### 5 - Update mod dict with newly generated tokens
|
729 |
+
# We rely on gather / scatter for batched operations
|
730 |
+
top_pos = torch.gather(mod_pos, -1, top_indices) # (B, num_select)
|
731 |
+
if write_all_predictions:
|
732 |
+
uncond_dict[target_mod]['tensor'][:, mod_pos] = all_samples
|
733 |
+
else:
|
734 |
+
uncond_dict[target_mod]['tensor'] = torch.scatter(uncond_dict[target_mod]['tensor'], -1, top_pos, top_samples)
|
735 |
+
uncond_dict[target_mod]['input_mask'] = torch.scatter(uncond_dict[target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool))
|
736 |
+
uncond_dict[target_mod]['target_mask'] = torch.scatter(uncond_dict[target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool))
|
737 |
+
# Update conditioning dicts
|
738 |
+
for i in range(len(cond_dicts)):
|
739 |
+
cond_dicts[i][target_mod]['tensor'] = torch.scatter(cond_dicts[i][target_mod]['tensor'], -1, top_pos, top_samples)
|
740 |
+
cond_dicts[i][target_mod]['input_mask'] = torch.scatter(cond_dicts[i][target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool))
|
741 |
+
cond_dicts[i][target_mod]['target_mask'] = torch.scatter(cond_dicts[i][target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool))
|
742 |
+
|
743 |
+
return uncond_dict, cond_dicts
|
744 |
+
|
745 |
+
def forward_enc_dec_roar_batched(self, mod_dict, target_mod, num_select, seed=None):
|
746 |
+
# Encoder
|
747 |
+
encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d)
|
748 |
+
for mod, d in mod_dict.items()
|
749 |
+
if mod in self.model.encoder_embeddings}
|
750 |
+
encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict)
|
751 |
+
x = encoder_tokens + encoder_emb
|
752 |
+
x = self.model.forward_encoder(x, encoder_mask)
|
753 |
+
|
754 |
+
# Decoder
|
755 |
+
context = self.model.decoder_proj_context(x) + encoder_emb
|
756 |
+
decoder_mod_dict = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])}
|
757 |
+
decoder_tokens, decoder_emb, decoder_mask, decoder_mod_mask, mod_pos = self.forward_mask_decoder_roar(decoder_mod_dict, target_mod, num_select, seed=seed)
|
758 |
+
y = decoder_tokens + decoder_emb
|
759 |
+
y = self.model.forward_decoder(y, context, encoder_mask, None)
|
760 |
+
B, N, D = y.shape
|
761 |
+
logits = self.model.forward_logits(y, decoder_mod_dict, decoder_mod_mask)[target_mod]
|
762 |
+
logits = logits.reshape(B, N, -1)
|
763 |
+
|
764 |
+
return logits, mod_pos
|
765 |
+
|
766 |
+
def roar_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p, seed=None):
|
767 |
+
"""ROAR = Random Order Autoregression"""
|
768 |
+
|
769 |
+
logits, mod_pos = self.forward_enc_dec_roar_batched(mod_dict, target_mod, num_select, seed=seed)
|
770 |
+
|
771 |
+
# Simple sampling
|
772 |
+
samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k=top_k, top_p=top_p)
|
773 |
+
|
774 |
+
# Update mod dict
|
775 |
+
# We rely on scatter for batched operations
|
776 |
+
select_pos = mod_pos
|
777 |
+
mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, select_pos, samples)
|
778 |
+
mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool))
|
779 |
+
mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool))
|
780 |
+
|
781 |
+
return mod_dict
|
782 |
+
|
783 |
+
def guided_roar_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p,
|
784 |
+
conditioning=[], guidance_scale=1.0, seed=None):
|
785 |
+
"""ROAR = Random Order Autoregression"""
|
786 |
+
|
787 |
+
### 1 - First pass, with conditioning
|
788 |
+
logits_cond, _ = self.forward_enc_dec_roar_batched(mod_dict, target_mod, num_select, seed=seed)
|
789 |
+
|
790 |
+
### 2 - Second pass, without conditioning
|
791 |
+
mod_dict_uncond = copy.deepcopy(mod_dict)
|
792 |
+
for mod in conditioning:
|
793 |
+
if self.model.modality_info[mod]['type'] in ['seq', 'seq_token']:
|
794 |
+
mod_dict_uncond = empty_seq_modality(mod_dict_uncond, mod)
|
795 |
+
elif self.model.modality_info[mod]['type'] in ['seq_emb']:
|
796 |
+
mod_dict_uncond = empty_seq_emb_modality(mod_dict_uncond, mod)
|
797 |
+
else:
|
798 |
+
mod_dict_uncond = empty_img_modality(mod_dict_uncond, mod)
|
799 |
+
|
800 |
+
logits_uncond, mod_pos = self.forward_enc_dec_roar_batched(mod_dict_uncond, target_mod, num_select, seed=seed)
|
801 |
+
|
802 |
+
### 3 - Classifier-free guidance
|
803 |
+
logits = logits_uncond + (logits_cond - logits_uncond) * guidance_scale
|
804 |
+
|
805 |
+
### 4 - Simple sampling
|
806 |
+
samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k=top_k, top_p=top_p)
|
807 |
+
|
808 |
+
### 5 - Update mod dict
|
809 |
+
# We rely on gather / scatter for batched operations
|
810 |
+
select_pos = mod_pos
|
811 |
+
mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, select_pos, samples)
|
812 |
+
mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool))
|
813 |
+
mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool))
|
814 |
+
|
815 |
+
return mod_dict
|
816 |
+
|
817 |
+
def multi_guided_roar_step_batched(self, uncond_dict, cond_dicts, cond_weights, target_mod,
|
818 |
+
num_select, temperature, top_k, top_p, seed=None):
|
819 |
+
|
820 |
+
### 1 - Conditional forward passes (one for each guided condition)
|
821 |
+
logits_cond_all = []
|
822 |
+
for cond_dict in cond_dicts:
|
823 |
+
logits_cond_i, _ = self.forward_enc_dec_roar_batched(cond_dict, target_mod, num_select, seed=seed)
|
824 |
+
logits_cond_all.append(logits_cond_i)
|
825 |
+
|
826 |
+
### 2 - Unconditional forward pass
|
827 |
+
logits_uncond, mod_pos = self.forward_enc_dec_roar_batched(uncond_dict, target_mod, num_select, seed=seed)
|
828 |
+
|
829 |
+
### 3 - Conjunction of multiple conditions: l_uncond + sum_i{w_i * (l_cond_i - l_uncond)}
|
830 |
+
# See https://arxiv.org/abs/2206.01714
|
831 |
+
logits = logits_uncond + torch.stack([w * (logits_cond - logits_uncond) for w, logits_cond in zip(cond_weights, logits_cond_all)]).sum(dim=0)
|
832 |
+
|
833 |
+
### 4 - Simple sampling
|
834 |
+
samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k=top_k, top_p=top_p)
|
835 |
+
|
836 |
+
### 5 - Update mod dict
|
837 |
+
# We rely on gather / scatter for batched operations
|
838 |
+
select_pos = mod_pos
|
839 |
+
uncond_dict[target_mod]['tensor'] = torch.scatter(uncond_dict[target_mod]['tensor'], -1, select_pos, samples)
|
840 |
+
uncond_dict[target_mod]['input_mask'] = torch.scatter(uncond_dict[target_mod]['input_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool))
|
841 |
+
uncond_dict[target_mod]['target_mask'] = torch.scatter(uncond_dict[target_mod]['target_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool))
|
842 |
+
# Update conditioning dicts
|
843 |
+
for i in range(len(cond_dicts)):
|
844 |
+
cond_dicts[i][target_mod]['tensor'] = torch.scatter(cond_dicts[i][target_mod]['tensor'], -1, select_pos, samples)
|
845 |
+
cond_dicts[i][target_mod]['input_mask'] = torch.scatter(cond_dicts[i][target_mod]['input_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool))
|
846 |
+
cond_dicts[i][target_mod]['target_mask'] = torch.scatter(cond_dicts[i][target_mod]['target_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool))
|
847 |
+
|
848 |
+
return uncond_dict, cond_dicts
|
849 |
+
|
850 |
+
def autoregressive_step_batched(self, mod_dict, target_mod, temperature, top_k: Union[float, int], top_p: float,
|
851 |
+
use_eos=True, eos_token=None, start_tokens=None, text_tokenizer=None, seed=None):
|
852 |
+
|
853 |
+
# Encoder
|
854 |
+
encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d)
|
855 |
+
for mod, d in mod_dict.items()
|
856 |
+
if mod in self.model.encoder_embeddings}
|
857 |
+
encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict)
|
858 |
+
x = encoder_tokens + encoder_emb
|
859 |
+
x = self.model.forward_encoder(x, encoder_mask) # B, N, D
|
860 |
+
|
861 |
+
# Get batch size
|
862 |
+
B = x.shape[0]
|
863 |
+
|
864 |
+
# Decoder
|
865 |
+
context = self.model.decoder_proj_context(x) + encoder_emb
|
866 |
+
decoder_mod_dict = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])}
|
867 |
+
decoder_ids, decoder_emb, decoder_mask, decoder_mod_mask, mod_pos = self.forward_mask_decoder_autoregressive(decoder_mod_dict, target_mod, seed=seed)
|
868 |
+
device = decoder_ids.device
|
869 |
+
seq_len = self.model.modality_info[target_mod]['max_tokens']
|
870 |
+
|
871 |
+
if use_eos and eos_token is None:
|
872 |
+
# The eos_token is the final sentinel token provided
|
873 |
+
eos_token = decoder_ids[0][decoder_mask[0] == 0][-1] # Assumes the EOS token is the same for all
|
874 |
+
if use_eos:
|
875 |
+
eos_token = eos_token.to(device)
|
876 |
+
|
877 |
+
# If no start_tokens, just use the beginning of the actual target (i.e., a sentinel token)
|
878 |
+
out = decoder_ids[:, :1] if start_tokens is None else start_tokens.to(device)
|
879 |
+
# Set decoder_tokens to None, we do not use them for decoding
|
880 |
+
decoder_ids = None
|
881 |
+
|
882 |
+
# If all samples of the batch have eos, return early
|
883 |
+
if use_eos and (out == eos_token).any(dim=-1).all():
|
884 |
+
return out
|
885 |
+
|
886 |
+
y_emb = decoder_emb[:, :seq_len]
|
887 |
+
seq_len = y_emb.shape[1]
|
888 |
+
|
889 |
+
# Auto-regressive decoding and sampling
|
890 |
+
for i in range(seq_len):
|
891 |
+
cur_len = out.shape[1]
|
892 |
+
# Convert ids into word embeddings and add corresponding posembs + modemb
|
893 |
+
y = self.model.decoder_embeddings[target_mod].token_emb(out) + y_emb[:, :cur_len]
|
894 |
+
# Build causal mask
|
895 |
+
causal_mask = torch.ones((cur_len, cur_len), dtype=torch.bool, device=y.device).triu(1)
|
896 |
+
causal_mask = repeat(causal_mask, "n1 n2 -> b n1 n2", b=B)
|
897 |
+
|
898 |
+
y = self.model.forward_decoder(y, context, encoder_mask, causal_mask)
|
899 |
+
logits = self.model.forward_logits(y, decoder_mod_dict, decoder_mod_mask[:, :cur_len])[target_mod]
|
900 |
+
logits = rearrange(logits, "(b n) d -> b n d", b=B, n=cur_len)
|
901 |
+
last_logits = logits[:, -1]
|
902 |
+
|
903 |
+
# Sample token for the newly generated logit
|
904 |
+
if np.isclose(temperature, 0, atol=1e-10):
|
905 |
+
sample = torch.argmax(last_logits, dim=-1, keepdim=True)
|
906 |
+
else:
|
907 |
+
filtered_logits = self.top_k_top_p_filtering(last_logits, top_k, top_p)
|
908 |
+
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
909 |
+
sample = torch.multinomial(probs, 1)
|
910 |
+
out = torch.cat((out, sample), dim=-1)
|
911 |
+
|
912 |
+
if use_eos and (out == eos_token).any(dim=-1).all():
|
913 |
+
break
|
914 |
+
|
915 |
+
mod_dict = self.merge_sequences_batched(mod_dict, out, target_mod, text_tokenizer)
|
916 |
+
|
917 |
+
return mod_dict
|
918 |
+
|
919 |
+
def guided_autoregressive_step_batched(self, mod_dict, target_mod, temperature, top_k: Union[float, int], top_p: float,
|
920 |
+
use_eos=True, eos_token=None, start_tokens=None, text_tokenizer=None,
|
921 |
+
conditioning=[], guidance_scale=1.0, seed=None):
|
922 |
+
|
923 |
+
### 1 - Encoder forward pass, with conditioning
|
924 |
+
|
925 |
+
# Encoder
|
926 |
+
encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d)
|
927 |
+
for mod, d in mod_dict.items()
|
928 |
+
if mod in self.model.encoder_embeddings}
|
929 |
+
encoder_tokens, encoder_emb, encoder_mask_cond, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict)
|
930 |
+
x = encoder_tokens + encoder_emb
|
931 |
+
x = self.model.forward_encoder(x, encoder_mask_cond) # B, N, D
|
932 |
+
|
933 |
+
# Get batch size
|
934 |
+
B = x.shape[0]
|
935 |
+
|
936 |
+
# Decoder
|
937 |
+
context_cond = self.model.decoder_proj_context(x) + encoder_emb
|
938 |
+
decoder_mod_dict_cond = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])}
|
939 |
+
decoder_ids, decoder_emb, decoder_mask, decoder_mod_mask_cond, mod_pos = self.forward_mask_decoder_autoregressive(decoder_mod_dict_cond, target_mod, seed=seed)
|
940 |
+
device = decoder_ids.device
|
941 |
+
seq_len = self.model.modality_info[target_mod]['max_tokens']
|
942 |
+
|
943 |
+
|
944 |
+
### 2 - Encoder forward pass, without conditioning
|
945 |
+
|
946 |
+
mod_dict_uncond = copy.deepcopy(mod_dict)
|
947 |
+
for mod in conditioning:
|
948 |
+
if self.model.modality_info[mod]['type'] in ['seq', 'seq_token']:
|
949 |
+
mod_dict_uncond = empty_seq_modality(mod_dict_uncond, mod)
|
950 |
+
elif self.model.modality_info[mod]['type'] in ['seq_emb']:
|
951 |
+
mod_dict_uncond = empty_seq_emb_modality(mod_dict_uncond, mod)
|
952 |
+
else:
|
953 |
+
mod_dict_uncond = empty_img_modality(mod_dict_uncond, mod)
|
954 |
+
|
955 |
+
# Encoder
|
956 |
+
encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d)
|
957 |
+
for mod, d in mod_dict_uncond.items()
|
958 |
+
if mod in self.model.encoder_embeddings}
|
959 |
+
encoder_tokens, encoder_emb, encoder_mask_uncond, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict)
|
960 |
+
x = encoder_tokens + encoder_emb
|
961 |
+
x = self.model.forward_encoder(x, encoder_mask_uncond) # B, N, D
|
962 |
+
|
963 |
+
# Decoder
|
964 |
+
context_uncond = self.model.decoder_proj_context(x) + encoder_emb
|
965 |
+
decoder_mod_dict_uncond = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])}
|
966 |
+
decoder_ids, decoder_emb, decoder_mask, decoder_mod_mask_uncond, mod_pos = self.forward_mask_decoder_autoregressive(decoder_mod_dict_uncond, target_mod, seed=seed)
|
967 |
+
|
968 |
+
|
969 |
+
if use_eos and eos_token is None:
|
970 |
+
# The eos_token is the final sentinel token provided
|
971 |
+
eos_token = decoder_ids[0][decoder_mask[0] == 0][-1] # Assumes the EOS token is the same for all
|
972 |
+
if use_eos:
|
973 |
+
eos_token = eos_token.to(device)
|
974 |
+
|
975 |
+
# If no start_tokens, just use the beginning of the actual target (i.e., a sentinel token)
|
976 |
+
out = decoder_ids[:, :1] if start_tokens is None else start_tokens.to(device)
|
977 |
+
# Set decoder_tokens to None, we do not use them for decoding
|
978 |
+
decoder_ids = None
|
979 |
+
|
980 |
+
# If all samples of the batch have eos, return early
|
981 |
+
if use_eos and (out == eos_token).any(dim=-1).all():
|
982 |
+
return out
|
983 |
+
|
984 |
+
y_emb = decoder_emb[:, :seq_len]
|
985 |
+
seq_len = y_emb.shape[1]
|
986 |
+
|
987 |
+
### 3 - Auto-regressive decoding and sampling
|
988 |
+
for i in range(seq_len):
|
989 |
+
cur_len = out.shape[1]
|
990 |
+
# Convert ids into word embeddings and add corresponding posembs + modemb
|
991 |
+
y = self.model.decoder_embeddings[target_mod].token_emb(out) + y_emb[:, :cur_len]
|
992 |
+
# Build causal mask
|
993 |
+
causal_mask = torch.ones((cur_len, cur_len), dtype=torch.bool, device=y.device).triu(1)
|
994 |
+
causal_mask = repeat(causal_mask, "n1 n2 -> b n1 n2", b=B)
|
995 |
+
|
996 |
+
### 3a - Decoder forward pass, with conditioning
|
997 |
+
y_cond = self.model.forward_decoder(y, context_cond, encoder_mask_cond, causal_mask)
|
998 |
+
logits_cond = self.model.forward_logits(y_cond, decoder_mod_dict_cond, decoder_mod_mask_cond[:, :cur_len])[target_mod]
|
999 |
+
logits_cond = rearrange(logits_cond, "(b n) d -> b n d", b=B, n=cur_len)
|
1000 |
+
last_logits_cond = logits_cond[:, -1]
|
1001 |
+
|
1002 |
+
### 3b - Decoder forward pass, without conditioning
|
1003 |
+
y_uncond = self.model.forward_decoder(y, context_uncond, encoder_mask_uncond, causal_mask)
|
1004 |
+
logits_uncond = self.model.forward_logits(y_uncond, decoder_mod_dict_uncond, decoder_mod_mask_uncond[:, :cur_len])[target_mod]
|
1005 |
+
logits_uncond = rearrange(logits_uncond, "(b n) d -> b n d", b=B, n=cur_len)
|
1006 |
+
last_logits_uncond = logits_uncond[:, -1]
|
1007 |
+
|
1008 |
+
### 3c - Classifier-free guidance
|
1009 |
+
last_logits = last_logits_uncond + (last_logits_cond - last_logits_uncond) * guidance_scale
|
1010 |
+
|
1011 |
+
# Sample token for the newly generated logit
|
1012 |
+
if np.isclose(temperature, 0, atol=1e-10):
|
1013 |
+
sample = torch.argmax(last_logits, dim=-1, keepdim=True)
|
1014 |
+
else:
|
1015 |
+
filtered_logits = self.top_k_top_p_filtering(last_logits, top_k, top_p)
|
1016 |
+
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
1017 |
+
sample = torch.multinomial(probs, 1)
|
1018 |
+
out = torch.cat((out, sample), dim=-1)
|
1019 |
+
|
1020 |
+
if use_eos and (out == eos_token).any(dim=-1).all():
|
1021 |
+
break
|
1022 |
+
|
1023 |
+
mod_dict = self.merge_sequences_batched(mod_dict, out, target_mod, text_tokenizer)
|
1024 |
+
|
1025 |
+
return mod_dict
|
1026 |
+
|
1027 |
+
|
1028 |
+
@torch.no_grad()
|
1029 |
+
def generate(self, mod_dict, schedule, top_k=0.0, top_p=0.0, text_tokenizer=None, verbose=False, seed=None):
|
1030 |
+
""" Generates a sequence of tokens from the input modalities.
|
1031 |
+
:param mod_dict: Dictionary of modalities.
|
1032 |
+
:param schedule: Schedule of modalities to use.
|
1033 |
+
List of dictionaries containing {target_domain, scheme, num_tokens, temperature, cfg_scale, cfg_cond_domains}.
|
1034 |
+
:param top_k: top_k > 0: Keep only top k tokens with highest probability (a.k.a. top-k filtering).
|
1035 |
+
:param top_p: top_p > 0.0: Keep the top tokens with cumulative probability >= top_p (a.k.a. nucleus filtering).
|
1036 |
+
:param text_tokenizer: Text tokenizer.
|
1037 |
+
:param verbose: Whether to print progress.
|
1038 |
+
:param seed: Random seed.
|
1039 |
+
:return: Generated mod dict.
|
1040 |
+
"""
|
1041 |
+
|
1042 |
+
# Input embedding -> tokenizes the modalities - Many are placeholder for now
|
1043 |
+
mod_dict = copy.deepcopy(mod_dict)
|
1044 |
+
|
1045 |
+
for step, schedule_step_info in tqdm(enumerate(schedule), disable=not verbose):
|
1046 |
+
target_mod = schedule_step_info['target_domain']
|
1047 |
+
temp = schedule_step_info['temperature']
|
1048 |
+
cfg_scale = schedule_step_info.get('cfg_scale', 1.0)
|
1049 |
+
cfg_conditioning = schedule_step_info.get('cfg_cond_domains', [])
|
1050 |
+
seed_i = seed + step if seed is not None else None
|
1051 |
+
|
1052 |
+
if self.model.modality_info[target_mod]['type'] == 'img':
|
1053 |
+
scheme = schedule_step_info['scheme']
|
1054 |
+
num_select = schedule_step_info['num_tokens']
|
1055 |
+
|
1056 |
+
if scheme.lower() == 'maskgit':
|
1057 |
+
if cfg_scale == 1.0 or len(cfg_conditioning) == 0:
|
1058 |
+
mod_dict = self.maskgit_step_batched(
|
1059 |
+
mod_dict, target_mod, num_select, temperature=temp,
|
1060 |
+
top_k=top_k, top_p=top_p, seed=seed_i
|
1061 |
+
)
|
1062 |
+
else:
|
1063 |
+
mod_dict = self.guided_maskgit_step_batched(
|
1064 |
+
mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p,
|
1065 |
+
conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i
|
1066 |
+
)
|
1067 |
+
elif scheme.lower() == 'roar':
|
1068 |
+
if cfg_scale == 1.0 or len(cfg_conditioning) == 0:
|
1069 |
+
mod_dict = self.roar_step_batched(
|
1070 |
+
mod_dict, target_mod, num_select, temperature=temp,
|
1071 |
+
top_k=top_k, top_p=top_p, seed=seed_i
|
1072 |
+
)
|
1073 |
+
else:
|
1074 |
+
mod_dict = self.guided_roar_step_batched(
|
1075 |
+
mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p,
|
1076 |
+
conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i
|
1077 |
+
)
|
1078 |
+
else:
|
1079 |
+
raise ValueError("Invalid sampling scheme")
|
1080 |
+
elif self.model.modality_info[target_mod]['type'] in ['seq', 'seq_token']:
|
1081 |
+
if cfg_scale == 1.0 or len(cfg_conditioning) == 0:
|
1082 |
+
mod_dict = self.autoregressive_step_batched(
|
1083 |
+
mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p,
|
1084 |
+
text_tokenizer=text_tokenizer, seed=seed_i
|
1085 |
+
)
|
1086 |
+
else:
|
1087 |
+
mod_dict = self.guided_autoregressive_step_batched(
|
1088 |
+
mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p,
|
1089 |
+
text_tokenizer=text_tokenizer, conditioning=cfg_conditioning,
|
1090 |
+
guidance_scale=cfg_scale, seed=seed_i
|
1091 |
+
)
|
1092 |
+
else:
|
1093 |
+
raise ValueError("Invalid schedule")
|
1094 |
+
|
1095 |
+
return mod_dict
|
1096 |
+
|
1097 |
+
|
1098 |
+
@torch.no_grad()
|
1099 |
+
def generate_iter(self, mod_dict, schedule, top_k=0.0, top_p=0.0, text_tokenizer=None, verbose=False, seed=None):
|
1100 |
+
""" Iterator that generates a sequence of tokens from the input modalities step by step.
|
1101 |
+
:param mod_dict: Dictionary of modalities.
|
1102 |
+
:param schedule: Schedule of modalities to use.
|
1103 |
+
List of dictionaries containing {target_domain, scheme, num_tokens, temperature, cfg_scale, cfg_cond_domains}.
|
1104 |
+
:param top_k: top_k > 0: Keep only top k tokens with highest probability (a.k.a. top-k filtering).
|
1105 |
+
:param top_p: top_p > 0.0: Keep the top tokens with cumulative probability >= top_p (a.k.a. nucleus filtering).
|
1106 |
+
:param text_tokenizer: Text tokenizer.
|
1107 |
+
:param verbose: Whether to print progress.
|
1108 |
+
:param seed: Random seed.
|
1109 |
+
:return: Iterator of generated mod dict.
|
1110 |
+
"""
|
1111 |
+
|
1112 |
+
# Input embedding -> tokenizes the modalities - Many are placeholder for now
|
1113 |
+
mod_dict = copy.deepcopy(mod_dict)
|
1114 |
+
|
1115 |
+
for step, schedule_step_info in tqdm(enumerate(schedule), disable=not verbose):
|
1116 |
+
target_mod = schedule_step_info['target_domain']
|
1117 |
+
temp = schedule_step_info['temperature']
|
1118 |
+
cfg_scale = schedule_step_info.get('cfg_scale', 1.0)
|
1119 |
+
cfg_conditioning = schedule_step_info.get('cfg_cond_domains', [])
|
1120 |
+
seed_i = seed + step if seed is not None else None
|
1121 |
+
|
1122 |
+
if self.model.modality_info[target_mod]['type'] == 'img':
|
1123 |
+
scheme = schedule_step_info['scheme']
|
1124 |
+
num_select = schedule_step_info['num_tokens']
|
1125 |
+
|
1126 |
+
if scheme.lower() == 'maskgit':
|
1127 |
+
if cfg_scale == 1.0 or len(cfg_conditioning) == 0:
|
1128 |
+
mod_dict = self.maskgit_step_batched(
|
1129 |
+
mod_dict, target_mod, num_select, temperature=temp,
|
1130 |
+
top_k=top_k, top_p=top_p, seed=seed_i
|
1131 |
+
)
|
1132 |
+
else:
|
1133 |
+
mod_dict = self.guided_maskgit_step_batched(
|
1134 |
+
mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p,
|
1135 |
+
conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i,
|
1136 |
+
write_all_predictions=True
|
1137 |
+
)
|
1138 |
+
elif scheme.lower() == 'roar':
|
1139 |
+
if cfg_scale == 1.0 or len(cfg_conditioning) == 0:
|
1140 |
+
mod_dict = self.roar_step_batched(
|
1141 |
+
mod_dict, target_mod, num_select, temperature=temp,
|
1142 |
+
top_k=top_k, top_p=top_p, seed=seed_i
|
1143 |
+
)
|
1144 |
+
else:
|
1145 |
+
mod_dict = self.guided_roar_step_batched(
|
1146 |
+
mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p,
|
1147 |
+
conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i
|
1148 |
+
)
|
1149 |
+
else:
|
1150 |
+
raise ValueError("Invalid sampling scheme")
|
1151 |
+
elif self.model.modality_info[target_mod]['type'] in ['seq', 'seq_token']:
|
1152 |
+
if cfg_scale == 1.0 or len(cfg_conditioning) == 0:
|
1153 |
+
mod_dict = self.autoregressive_step_batched(
|
1154 |
+
mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p,
|
1155 |
+
text_tokenizer=text_tokenizer, seed=seed_i
|
1156 |
+
)
|
1157 |
+
else:
|
1158 |
+
mod_dict = self.guided_autoregressive_step_batched(
|
1159 |
+
mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p,
|
1160 |
+
text_tokenizer=text_tokenizer, conditioning=cfg_conditioning,
|
1161 |
+
guidance_scale=cfg_scale, seed=seed_i
|
1162 |
+
)
|
1163 |
+
else:
|
1164 |
+
raise ValueError("Invalid schedule")
|
1165 |
+
|
1166 |
+
yield mod_dict
|
1167 |
+
|
1168 |
+
@torch.no_grad()
|
1169 |
+
def generate_multi_guided(self, uncond_dict, cond_dicts, schedule, top_k=0.0, top_p=0.0,
|
1170 |
+
text_tokenizer=None, verbose=False, seed=None):
|
1171 |
+
# Generation function for multiple weighted conditions
|
1172 |
+
|
1173 |
+
# To detect when a modality has finished generating, we keep track of the current target modality
|
1174 |
+
cur_target_mod = schedule[0]['target_domain']
|
1175 |
+
|
1176 |
+
uncond_dict = copy.deepcopy(uncond_dict)
|
1177 |
+
cond_dicts = copy.deepcopy(cond_dicts)
|
1178 |
+
|
1179 |
+
# Add the to-be-generated modality to the conditional dicts
|
1180 |
+
for i in range(len(cond_dicts)):
|
1181 |
+
cond_dicts[i][cur_target_mod] = copy.deepcopy(uncond_dict[cur_target_mod])
|
1182 |
+
|
1183 |
+
for step, schedule_step_info in tqdm(enumerate(schedule), disable=not verbose):
|
1184 |
+
target_mod = schedule_step_info['target_domain']
|
1185 |
+
temp = schedule_step_info['temperature']
|
1186 |
+
num_select = schedule_step_info['num_tokens']
|
1187 |
+
cond_weights = schedule_step_info['cfg_scale']
|
1188 |
+
|
1189 |
+
# Once a modality is fully generated, add it as a new condition
|
1190 |
+
if cur_target_mod != target_mod:
|
1191 |
+
for i in range(len(cond_dicts)):
|
1192 |
+
# Remove the previously generated modality from the conditionings
|
1193 |
+
del cond_dicts[i][cur_target_mod]
|
1194 |
+
# Add the next modality to be generated to the conditionings
|
1195 |
+
cond_dicts[i][target_mod] = copy.deepcopy(uncond_dict[target_mod])
|
1196 |
+
|
1197 |
+
# Remove the fully generated modality from the unconditional dict inputs
|
1198 |
+
uncond_dict[cur_target_mod]['input_mask'][:] = True
|
1199 |
+
|
1200 |
+
# Add the previously generated modality as an additional condition
|
1201 |
+
new_cond = {}
|
1202 |
+
new_cond[cur_target_mod] = copy.deepcopy(uncond_dict[cur_target_mod])
|
1203 |
+
new_cond[cur_target_mod]['input_mask'][:] = False
|
1204 |
+
new_cond[cur_target_mod]['target_mask'][:] = True
|
1205 |
+
new_cond[target_mod] = copy.deepcopy(uncond_dict[target_mod])
|
1206 |
+
cond_dicts.append(new_cond)
|
1207 |
+
|
1208 |
+
cur_target_mod = target_mod
|
1209 |
+
|
1210 |
+
if self.model.modality_info[target_mod]['type'] == 'img':
|
1211 |
+
scheme = schedule_step_info['scheme']
|
1212 |
+
|
1213 |
+
if scheme.lower() == 'maskgit':
|
1214 |
+
uncond_dict, cond_dicts = self.multi_guided_maskgit_step_batched(
|
1215 |
+
uncond_dict, cond_dicts, cond_weights, target_mod, num_select, temp, top_k, top_p, seed=seed
|
1216 |
+
)
|
1217 |
+
elif scheme.lower() == 'roar':
|
1218 |
+
uncond_dict, cond_dicts = self.multi_guided_roar_step_batched(
|
1219 |
+
uncond_dict, cond_dicts, cond_weights, target_mod, num_select, temp, top_k, top_p, seed=seed
|
1220 |
+
)
|
1221 |
+
else:
|
1222 |
+
raise ValueError("Invalid sampling scheme")
|
1223 |
+
|
1224 |
+
else:
|
1225 |
+
raise NotImplementedError("Only image modalities are supported for now")
|
1226 |
+
|
1227 |
+
return uncond_dict
|
1228 |
+
|
1229 |
+
@torch.no_grad()
|
1230 |
+
def generate_sam_dense(self, mod_dict, schedule, text_tokenizer, batch_size=16,
|
1231 |
+
key='sam_instance', top_k=0.0, top_p=0.0, seed=None, verbose=False):
|
1232 |
+
# Generation function for dense SAM instance prediction
|
1233 |
+
|
1234 |
+
device = mod_dict[list(mod_dict.keys())[0]]['tensor'].device
|
1235 |
+
mod_dict = copy.deepcopy(mod_dict)
|
1236 |
+
# Repeat the input batch to match the batch size
|
1237 |
+
expanded_batch = expand_to_batch(copy.deepcopy(mod_dict), batch_size=batch_size)
|
1238 |
+
|
1239 |
+
# Filter the schedule to only include the key domain
|
1240 |
+
schedule = [s for s in schedule if s['target_domain'] == key]
|
1241 |
+
|
1242 |
+
out_dict = self.generate(
|
1243 |
+
expanded_batch, schedule, text_tokenizer=text_tokenizer,
|
1244 |
+
verbose=verbose, seed=seed,
|
1245 |
+
top_p=top_p, top_k=top_k,
|
1246 |
+
)
|
1247 |
+
|
1248 |
+
# Merge the batch generated sequences into one sequence
|
1249 |
+
sentinel_ids = set(get_sentinel_to_id_mapping(text_tokenizer).values())
|
1250 |
+
merged_seq = []
|
1251 |
+
|
1252 |
+
for i in range(batch_size):
|
1253 |
+
|
1254 |
+
input_seq = out_dict[key]['tensor'][i]
|
1255 |
+
input_seq = input_seq[out_dict[key]['input_mask'][i] == 0]
|
1256 |
+
input_seq = input_seq.tolist()
|
1257 |
+
|
1258 |
+
target_seq = out_dict[key]['tensor'][i]
|
1259 |
+
target_seq = target_seq[out_dict[key]['target_mask'][i] == 0]
|
1260 |
+
target_seq = target_seq.tolist()
|
1261 |
+
|
1262 |
+
merged_seq.extend(merge_span_masking(input_seq, target_seq, sentinel_ids=sentinel_ids))
|
1263 |
+
|
1264 |
+
merged_seq = torch.tensor(merged_seq, device=device).unsqueeze(0)
|
1265 |
+
|
1266 |
+
mod_dict[key] = {
|
1267 |
+
'tensor': merged_seq,
|
1268 |
+
'input_mask': torch.zeros(merged_seq.shape, dtype=torch.bool, device=device),
|
1269 |
+
'target_mask': torch.ones(merged_seq.shape, dtype=torch.bool, device=device),
|
1270 |
+
'decoder_attention_mask': torch.zeros(merged_seq.shape, dtype=torch.bool, device=device),
|
1271 |
+
}
|
1272 |
+
|
1273 |
+
return mod_dict
|
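For orientation, here is a minimal sketch of how the `generate` method above might be driven. It is not part of the diff: the sampler object, the modality keys (`tok_rgb@224`, `caption`, `det`), and the helper `run_generation` are assumptions for illustration; only the schedule keys (`target_domain`, `scheme`, `num_tokens`, `temperature`, `cfg_scale`, `cfg_cond_domains`) and the `generate` signature come from the code above.

def run_generation(sampler, mod_dict, text_tokenizer):
    # `sampler` is assumed to be an instance of the sampling class defined in this
    # file, already wrapping a trained 4M model; `mod_dict` is a pre-built modality dict.
    schedule = [
        # One MaskGIT step on an image token modality, guided by the caption condition.
        # In practice this entry is repeated until all image tokens are filled in.
        {'target_domain': 'tok_rgb@224', 'scheme': 'maskgit', 'num_tokens': 8,
         'temperature': 1.0, 'cfg_scale': 3.0, 'cfg_cond_domains': ['caption']},
        # A sequence modality (here hypothetically 'det' for bounding boxes) is decoded
        # autoregressively; 'scheme' and 'num_tokens' are not read for 'seq'-type modalities.
        {'target_domain': 'det', 'temperature': 0.7, 'cfg_scale': 1.0,
         'cfg_cond_domains': []},
    ]
    return sampler.generate(mod_dict, schedule, top_k=0.0, top_p=0.9,
                            text_tokenizer=text_tokenizer, verbose=True, seed=0)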
fourm/models/lora_utils.py
ADDED
@@ -0,0 +1,177 @@
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import List, Set, Optional, Type
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
|
20 |
+
SELF_ATTENTION_MODULES = {'Attention', 'NormAttention'}
|
21 |
+
CROSS_ATTENTION_MODULES = {'CrossAttention', 'NormCrossAttention'}
|
22 |
+
ATTENTION_MODULES = SELF_ATTENTION_MODULES | CROSS_ATTENTION_MODULES
|
23 |
+
MLP_MODULES = {'Mlp', 'GatedMlp', 'SwiGLUFFNFused'} # SwiGLUFFNFused is from DINOv2
|
24 |
+
TRANSFORMER_MODULES = ATTENTION_MODULES | MLP_MODULES
|
25 |
+
|
26 |
+
|
27 |
+
def get_LoRA_module_names(id: str) -> Set[str]:
|
28 |
+
""" Returns a list of module names that are LoRA-adapted for the given id. """
|
29 |
+
id = id.lower()
|
30 |
+
if id in ['selfattn', 'selfattention', 'self_attn', 'self_attention']:
|
31 |
+
return SELF_ATTENTION_MODULES
|
32 |
+
elif id in ['crossattn', 'crossattention', 'cross_attn', 'cross_attention']:
|
33 |
+
return CROSS_ATTENTION_MODULES
|
34 |
+
elif id in ['attn', 'attention']:
|
35 |
+
return ATTENTION_MODULES
|
36 |
+
elif id in ['mlp']:
|
37 |
+
return MLP_MODULES
|
38 |
+
elif id in ['all', 'transformer']:
|
39 |
+
return TRANSFORMER_MODULES
|
40 |
+
else:
|
41 |
+
raise ValueError(f'Unknown LoRA module id {id}.')
|
42 |
+
|
43 |
+
|
44 |
+
class LoRAWrapper(nn.Module):
|
45 |
+
"""Low-Rank Adaptation Wrapper for linear layers.
|
46 |
+
See https://arxiv.org/abs/2106.09685
|
47 |
+
|
48 |
+
Args:
|
49 |
+
linear: nn.Linear layer to wrap
|
50 |
+
rank: Rank of adaptation matrix B@A
|
51 |
+
scale: x = W_0@x + scale * B@A@x
|
52 |
+
num_packed_linear: Set to > 1 when wrapping e.g. packed kv, or qkv attention weights.
|
53 |
+
Weights will be initialized as if num_packed_linear = 1, but the LoRA bottleneck will
|
54 |
+
be num_packed_linear times larger.
|
55 |
+
"""
|
56 |
+
def __init__(self, linear: nn.Module, rank: int = 4, scale: float = 1.0, num_packed_linear: int = 1):
|
57 |
+
super().__init__()
|
58 |
+
self.rank = rank
|
59 |
+
self.scale = scale
|
60 |
+
self.in_features, self.out_features = linear.in_features, linear.out_features
|
61 |
+
assert num_packed_linear * rank <= min(self.in_features, self.out_features), \
|
62 |
+
f'LoRA rank {num_packed_linear} * {rank} must be less than or equal to {min(self.in_features, self.out_features)}'
|
63 |
+
|
64 |
+
self.linear = linear
|
65 |
+
self.lora_down = nn.Linear(self.in_features, num_packed_linear*rank, bias=False)
|
66 |
+
self.lora_up = nn.Linear(num_packed_linear*rank, self.out_features, bias=False)
|
67 |
+
|
68 |
+
nn.init.normal_(self.lora_down.weight, std=1/rank)
|
69 |
+
nn.init.zeros_(self.lora_up.weight)
|
70 |
+
|
71 |
+
def fuse_LoRA_into_linear(self) -> nn.Linear:
|
72 |
+
""" Returns a single nn.Linear layer with the LoRA matrix fused into the original one. """
|
73 |
+
fused_linear = nn.Linear(self.in_features, self.out_features, bias=self.linear.bias is not None)
|
74 |
+
fused_linear.weight.data = self.linear.weight + self.scale * (self.lora_up.weight @ self.lora_down.weight)
|
75 |
+
if self.linear.bias is not None:
|
76 |
+
fused_linear.bias.data = self.linear.bias
|
77 |
+
return fused_linear
|
78 |
+
|
79 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
80 |
+
""" LoRA adapted linear layer forward pass. """
|
81 |
+
return self.linear(x) + self.lora_up(self.lora_down(x)) * self.scale
|
82 |
+
|
83 |
+
|
84 |
+
def _find_modules(
|
85 |
+
model,
|
86 |
+
ancestor_class: Optional[Set[str]] = None,
|
87 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
88 |
+
exclude_children_of: Optional[List[Type[nn.Module]]] = [LoRAWrapper],
|
89 |
+
):
|
90 |
+
"""
|
91 |
+
Find all modules of a certain class (or union of classes) that are direct or
|
92 |
+
indirect descendants of other modules of a certain class (or union of classes).
|
93 |
+
|
94 |
+
Returns all matching modules, along with the parent of those modules and the
|
95 |
+
names they are referenced by.
|
96 |
+
|
97 |
+
Adapted from https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
98 |
+
"""
|
99 |
+
# Get the targets we should replace all linears under
|
100 |
+
if ancestor_class is not None:
|
101 |
+
ancestors = (
|
102 |
+
module
|
103 |
+
for module in model.modules()
|
104 |
+
if module.__class__.__name__ in ancestor_class
|
105 |
+
)
|
106 |
+
else:
|
107 |
+
# this, incase you want to naively iterate over all modules.
|
108 |
+
ancestors = [module for module in model.modules()]
|
109 |
+
|
110 |
+
# For each target find every linear_class module that isn't a child of a LoRA layer
|
111 |
+
for ancestor in ancestors:
|
112 |
+
for fullname, module in ancestor.named_modules():
|
113 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
114 |
+
# Find the direct parent if this is a descendant, not a child, of target
|
115 |
+
*path, name = fullname.split(".")
|
116 |
+
parent = ancestor
|
117 |
+
while path:
|
118 |
+
parent = parent.get_submodule(path.pop(0))
|
119 |
+
# Skip this linear if it's a child of a LoRA layer
|
120 |
+
if exclude_children_of and any(
|
121 |
+
[isinstance(parent, _class) for _class in exclude_children_of]
|
122 |
+
):
|
123 |
+
continue
|
124 |
+
# Otherwise, yield it
|
125 |
+
yield parent, name, module
|
126 |
+
|
127 |
+
|
128 |
+
def inject_trainable_LoRA(
|
129 |
+
model: nn.Module,
|
130 |
+
rank: int = 4,
|
131 |
+
scale: float = 1.0,
|
132 |
+
target_replace_modules: Set[str] = ATTENTION_MODULES
|
133 |
+
) -> None:
|
134 |
+
"""Replaces all linear layers of the specified modules with LoRA-adapted linear layers.
|
135 |
+
Modifies the model in-place.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
model: nn.Module to modify
|
139 |
+
rank: Rank of adaptation matrix B@A
|
140 |
+
scale: x = W_0@x + scale * B@A@x
|
141 |
+
target_replace_modules: Set of module names to replace linear layers in.
|
142 |
+
"""
|
143 |
+
for _module, name, _child_module in _find_modules(
|
144 |
+
model, target_replace_modules, search_class=[nn.Linear]
|
145 |
+
):
|
146 |
+
if sorted(name) == sorted('qkv'):
|
147 |
+
num_packed_linear = 3
|
148 |
+
elif sorted(name) in [sorted('kv'), sorted('qk'), sorted('qv')]:
|
149 |
+
num_packed_linear = 2
|
150 |
+
else:
|
151 |
+
num_packed_linear = 1
|
152 |
+
|
153 |
+
_module._modules[name] = LoRAWrapper(_child_module, rank=rank, scale=scale, num_packed_linear=num_packed_linear)
|
154 |
+
|
155 |
+
|
156 |
+
def fuse_LoRA_into_linear(
|
157 |
+
model: nn.Module,
|
158 |
+
target_replace_modules: Set[str] = ATTENTION_MODULES
|
159 |
+
) -> None:
|
160 |
+
"""Fuses all LoRA-adapted linear layers back into single linear layers.
|
161 |
+
Modifies the model in-place.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
model: nn.Module to modify
|
165 |
+
target_replace_modules: Set of module names whose LoRA-wrapped linear layers should be fused.
|
166 |
+
"""
|
167 |
+
for _module, name, _child_module in _find_modules(
|
168 |
+
model, target_replace_modules, search_class=[LoRAWrapper]
|
169 |
+
):
|
170 |
+
_module._modules[name] = _module._modules[name].fuse_LoRA_into_linear()
|
171 |
+
|
172 |
+
|
173 |
+
def unfreeze_all_LoRA_layers(model: nn.Module) -> None:
|
174 |
+
""" Unfreezes all LoRA-adapted linear layers. """
|
175 |
+
for name, param in model.named_parameters():
|
176 |
+
if 'lora' in name:
|
177 |
+
param.requires_grad = True
|
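Below is a small, self-contained sketch of how these LoRA utilities could be used. The toy `Attention` module and its dimensions are invented for illustration so that the class-name matching in `_find_modules` has something to latch onto; only the functions themselves (`get_LoRA_module_names`, `inject_trainable_LoRA`, `unfreeze_all_LoRA_layers`, `fuse_LoRA_into_linear`) come from the file above.

import torch.nn as nn
from fourm.models.lora_utils import (inject_trainable_LoRA, fuse_LoRA_into_linear,
                                     unfreeze_all_LoRA_layers, get_LoRA_module_names)

class Attention(nn.Module):  # class name matches SELF_ATTENTION_MODULES above
    def __init__(self, dim=64):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)   # packed qkv -> detected as num_packed_linear=3
        self.proj = nn.Linear(dim, dim)

model = nn.Sequential(Attention(), nn.Linear(64, 10))

# Wrap the linear layers inside self-attention blocks with LoRA adapters (in-place)
inject_trainable_LoRA(model, rank=4, scale=1.0,
                      target_replace_modules=get_LoRA_module_names('selfattn'))

# Freeze everything, then unfreeze only the LoRA parameters for training
for p in model.parameters():
    p.requires_grad = False
unfreeze_all_LoRA_layers(model)

# After training, fold the adapters back into plain nn.Linear layers
fuse_LoRA_into_linear(model, target_replace_modules=get_LoRA_module_names('selfattn'))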
fourm/utils/__init__.py
ADDED
@@ -0,0 +1,22 @@
1 |
+
from .misc import *
|
2 |
+
from .checkpoint import *
|
3 |
+
from .timm.cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
4 |
+
from .data_constants import *
|
5 |
+
from .dist import *
|
6 |
+
from .logger import *
|
7 |
+
from .timm.metrics import AverageMeter, accuracy
|
8 |
+
from .timm.mixup import FastCollateMixup, Mixup
|
9 |
+
from .timm.model import freeze, get_state_dict, unfreeze, unwrap_model
|
10 |
+
from .timm.model_builder import create_model
|
11 |
+
from .timm.model_ema import ModelEma, ModelEmaV2
|
12 |
+
from .native_scaler import NativeScalerWithGradNormCount
|
13 |
+
from .scheduler import cosine_scheduler, constant_scheduler, inverse_sqrt_scheduler
|
14 |
+
from .optim_factory import create_optimizer
|
15 |
+
from .timm.registry import model_entrypoint, register_model
|
16 |
+
from .timm.transforms import *
|
17 |
+
from .timm.transforms_factory import create_transform
|
18 |
+
from .tokenizer.text_tokenizer import *
|
19 |
+
from .s3_utils import *
|
20 |
+
from .run_name import *
|
21 |
+
from .generation_datasets import *
|
22 |
+
from .seeds import *
|
fourm/utils/checkpoint.py
ADDED
@@ -0,0 +1,185 @@
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# --------------------------------------------------------
|
15 |
+
# Based on the timm code base
|
16 |
+
# https://github.com/huggingface/pytorch-image-models
|
17 |
+
# --------------------------------------------------------
|
18 |
+
import io
|
19 |
+
import os
|
20 |
+
import ast
|
21 |
+
import json
|
22 |
+
from pathlib import Path
|
23 |
+
from safetensors.torch import load as load_st
|
24 |
+
|
25 |
+
import torch
|
26 |
+
|
27 |
+
from .dist import save_on_main, is_main_process
|
28 |
+
from .timm.model import get_state_dict
|
29 |
+
from .s3_utils import save_on_s3
|
30 |
+
|
31 |
+
|
32 |
+
def _load_checkpoint_for_ema(model_ema, checkpoint):
|
33 |
+
"""
|
34 |
+
Workaround for ModelEma._load_checkpoint to accept an already-loaded object
|
35 |
+
"""
|
36 |
+
mem_file = io.BytesIO()
|
37 |
+
torch.save(checkpoint, mem_file)
|
38 |
+
mem_file.seek(0)
|
39 |
+
model_ema._load_checkpoint(mem_file)
|
40 |
+
|
41 |
+
|
42 |
+
def load_state_dict(model, state_dict, prefix='', ignore_missing=''):
|
43 |
+
missing_keys = []
|
44 |
+
unexpected_keys = []
|
45 |
+
error_msgs = []
|
46 |
+
# copy state_dict so _load_from_state_dict can modify it
|
47 |
+
metadata = getattr(state_dict, '_metadata', None)
|
48 |
+
state_dict = state_dict.copy()
|
49 |
+
if metadata is not None:
|
50 |
+
state_dict._metadata = metadata
|
51 |
+
|
52 |
+
def load(module, prefix=''):
|
53 |
+
local_metadata = {} if metadata is None else metadata.get(
|
54 |
+
prefix[:-1], {})
|
55 |
+
module._load_from_state_dict(
|
56 |
+
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
57 |
+
for name, child in module._modules.items():
|
58 |
+
if child is not None:
|
59 |
+
load(child, prefix + name + '.')
|
60 |
+
|
61 |
+
load(model, prefix=prefix)
|
62 |
+
|
63 |
+
warn_missing_keys = []
|
64 |
+
ignore_missing_keys = []
|
65 |
+
for key in missing_keys:
|
66 |
+
keep_flag = True
|
67 |
+
for ignore_key in ignore_missing.split('|'):
|
68 |
+
if ignore_key in key:
|
69 |
+
keep_flag = False
|
70 |
+
break
|
71 |
+
if keep_flag:
|
72 |
+
warn_missing_keys.append(key)
|
73 |
+
else:
|
74 |
+
ignore_missing_keys.append(key)
|
75 |
+
|
76 |
+
missing_keys = warn_missing_keys
|
77 |
+
|
78 |
+
if len(missing_keys) > 0:
|
79 |
+
print("Weights of {} not initialized from pretrained model: {}".format(
|
80 |
+
model.__class__.__name__, missing_keys))
|
81 |
+
if len(unexpected_keys) > 0:
|
82 |
+
print("Weights from pretrained model not used in {}: {}".format(
|
83 |
+
model.__class__.__name__, unexpected_keys))
|
84 |
+
if len(ignore_missing_keys) > 0:
|
85 |
+
print("Ignored weights of {} not initialized from pretrained model: {}".format(
|
86 |
+
model.__class__.__name__, ignore_missing_keys))
|
87 |
+
if len(error_msgs) > 0:
|
88 |
+
print('\n'.join(error_msgs))
|
89 |
+
|
90 |
+
|
91 |
+
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, loss_balancer=None, model_ema=None, ckpt_name=None, use_s3=False, all_nodes=False):
|
92 |
+
output_dir = Path(args.output_dir)
|
93 |
+
epoch_name = str(epoch)
|
94 |
+
ckpt_name = ckpt_name or epoch_name
|
95 |
+
|
96 |
+
# Only create the save_dict on the main process, unless all_nodes is set to True
|
97 |
+
if is_main_process() or (all_nodes and args.gpu == 0):
|
98 |
+
checkpoint_path = os.path.join(output_dir, f'checkpoint-{ckpt_name}.pth')
|
99 |
+
|
100 |
+
to_save = {
|
101 |
+
'model': model_without_ddp.state_dict(),
|
102 |
+
'epoch': epoch,
|
103 |
+
'args': args,
|
104 |
+
'scaler': loss_scaler.state_dict(),
|
105 |
+
}
|
106 |
+
|
107 |
+
if optimizer is not None:
|
108 |
+
to_save['optimizer'] = optimizer.state_dict()
|
109 |
+
|
110 |
+
if loss_balancer is not None:
|
111 |
+
to_save['loss_balancer'] = loss_balancer.state_dict()
|
112 |
+
|
113 |
+
if model_ema is not None:
|
114 |
+
to_save['model_ema'] = get_state_dict(model_ema)
|
115 |
+
|
116 |
+
save_on_main(to_save, checkpoint_path)
|
117 |
+
|
118 |
+
if use_s3:
|
119 |
+
s3_path = os.path.join(args.s3_save_dir, f'checkpoint-{ckpt_name}.pth')
|
120 |
+
save_on_s3(checkpoint_path, s3_path, args.s3_endpoint)
|
121 |
+
|
122 |
+
|
123 |
+
def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
|
124 |
+
output_dir = Path(args.output_dir)
|
125 |
+
# torch.amp
|
126 |
+
if args.auto_resume and len(args.resume) == 0:
|
127 |
+
import glob
|
128 |
+
all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
|
129 |
+
latest_ckpt = -1
|
130 |
+
for ckpt in all_checkpoints:
|
131 |
+
t = ckpt.split('-')[-1].split('.')[0]
|
132 |
+
if t.isdigit():
|
133 |
+
latest_ckpt = max(int(t), latest_ckpt)
|
134 |
+
if latest_ckpt >= 0:
|
135 |
+
args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
|
136 |
+
print("Auto resume checkpoint: %s" % args.resume)
|
137 |
+
|
138 |
+
if args.resume:
|
139 |
+
if args.resume.startswith('https'):
|
140 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
141 |
+
args.resume, map_location='cpu')
|
142 |
+
else:
|
143 |
+
checkpoint = torch.load(args.resume, map_location='cpu')
|
144 |
+
model_without_ddp.load_state_dict(checkpoint['model'])
|
145 |
+
print("Resume checkpoint %s" % args.resume)
|
146 |
+
|
147 |
+
if 'optimizer' in checkpoint and 'epoch' in checkpoint:
|
148 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
149 |
+
args.start_epoch = checkpoint['epoch'] + 1
|
150 |
+
|
151 |
+
if 'scaler' in checkpoint:
|
152 |
+
loss_scaler.load_state_dict(checkpoint['scaler'])
|
153 |
+
print("With optim & sched!")
|
154 |
+
|
155 |
+
if hasattr(args, 'model_ema') and args.model_ema:
|
156 |
+
_load_checkpoint_for_ema(model_ema, {'state_dict_ema': checkpoint['model_ema']})
|
157 |
+
print("With EMA!")
|
158 |
+
|
159 |
+
def parse_metadata(metadata_str):
|
160 |
+
metadata = {}
|
161 |
+
for k, v in metadata_str.items():
|
162 |
+
try:
|
163 |
+
v_parsed = ast.literal_eval(v)
|
164 |
+
except:
|
165 |
+
v_parsed = v
|
166 |
+
metadata[k] = v_parsed
|
167 |
+
return metadata
|
168 |
+
|
169 |
+
def load_safetensors(safetensors_path, return_metadata=True):
|
170 |
+
with open(safetensors_path, 'rb') as f:
|
171 |
+
data = f.read()
|
172 |
+
|
173 |
+
tensors = load_st(data)
|
174 |
+
|
175 |
+
if not return_metadata:
|
176 |
+
return tensors
|
177 |
+
|
178 |
+
n_header = data[:8]
|
179 |
+
n = int.from_bytes(n_header, "little")
|
180 |
+
metadata_bytes = data[8 : 8 + n]
|
181 |
+
header = json.loads(metadata_bytes)
|
182 |
+
metadata = header.get("__metadata__", {})
|
183 |
+
metadata = parse_metadata(metadata)
|
184 |
+
|
185 |
+
return tensors, metadata
|
fourm/utils/clip/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .clip import *
|
2 |
+
from .model import *
|
fourm/utils/clip/clip.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Based on the CLIP code base
|
3 |
+
# https://github.com/openai/CLIP
|
4 |
+
# --------------------------------------------------------
|
5 |
+
|
6 |
+
import hashlib
|
7 |
+
import os
|
8 |
+
import urllib
|
9 |
+
import warnings
|
10 |
+
from typing import Any, Union, List
|
11 |
+
from pkg_resources import packaging
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from PIL import Image
|
15 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
from .model import build_model
|
19 |
+
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
20 |
+
|
21 |
+
try:
|
22 |
+
from torchvision.transforms import InterpolationMode
|
23 |
+
BICUBIC = InterpolationMode.BICUBIC
|
24 |
+
except ImportError:
|
25 |
+
BICUBIC = Image.BICUBIC
|
26 |
+
|
27 |
+
|
28 |
+
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
|
29 |
+
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
30 |
+
|
31 |
+
|
32 |
+
__all__ = ["available_models", "load", "tokenize"]
|
33 |
+
_tokenizer = _Tokenizer()
|
34 |
+
|
35 |
+
_MODELS = {
|
36 |
+
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
37 |
+
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
38 |
+
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
39 |
+
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
40 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
41 |
+
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
42 |
+
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
43 |
+
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
|
44 |
+
}
|
45 |
+
|
46 |
+
|
47 |
+
def _download(url: str, root: str):
|
48 |
+
os.makedirs(root, exist_ok=True)
|
49 |
+
filename = os.path.basename(url)
|
50 |
+
|
51 |
+
expected_sha256 = url.split("/")[-2]
|
52 |
+
download_target = os.path.join(root, filename)
|
53 |
+
|
54 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
55 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
56 |
+
|
57 |
+
if os.path.isfile(download_target):
|
58 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
59 |
+
return download_target
|
60 |
+
else:
|
61 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
62 |
+
|
63 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
64 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
65 |
+
while True:
|
66 |
+
buffer = source.read(8192)
|
67 |
+
if not buffer:
|
68 |
+
break
|
69 |
+
|
70 |
+
output.write(buffer)
|
71 |
+
loop.update(len(buffer))
|
72 |
+
|
73 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
74 |
+
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
|
75 |
+
|
76 |
+
return download_target
|
77 |
+
|
78 |
+
|
79 |
+
def _convert_image_to_rgb(image):
|
80 |
+
return image.convert("RGB")
|
81 |
+
|
82 |
+
|
83 |
+
def _transform(n_px):
|
84 |
+
return Compose([
|
85 |
+
Resize(n_px, interpolation=BICUBIC),
|
86 |
+
CenterCrop(n_px),
|
87 |
+
_convert_image_to_rgb,
|
88 |
+
ToTensor(),
|
89 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
90 |
+
])
|
91 |
+
|
92 |
+
|
93 |
+
def available_models() -> List[str]:
|
94 |
+
"""Returns the names of available CLIP models"""
|
95 |
+
return list(_MODELS.keys())
|
96 |
+
|
97 |
+
|
98 |
+
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
99 |
+
"""Load a CLIP model
|
100 |
+
|
101 |
+
Parameters
|
102 |
+
----------
|
103 |
+
name : str
|
104 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
105 |
+
|
106 |
+
device : Union[str, torch.device]
|
107 |
+
The device to put the loaded model
|
108 |
+
|
109 |
+
jit : bool
|
110 |
+
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
111 |
+
|
112 |
+
download_root: str
|
113 |
+
path to download the model files; by default, it uses "~/.cache/clip"
|
114 |
+
|
115 |
+
Returns
|
116 |
+
-------
|
117 |
+
model : torch.nn.Module
|
118 |
+
The CLIP model
|
119 |
+
|
120 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
121 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
122 |
+
"""
|
123 |
+
if name in _MODELS:
|
124 |
+
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
125 |
+
elif os.path.isfile(name):
|
126 |
+
model_path = name
|
127 |
+
else:
|
128 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
129 |
+
|
130 |
+
try:
|
131 |
+
# loading JIT archive
|
132 |
+
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
133 |
+
state_dict = None
|
134 |
+
except RuntimeError:
|
135 |
+
# loading saved state dict
|
136 |
+
if jit:
|
137 |
+
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
138 |
+
jit = False
|
139 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
140 |
+
|
141 |
+
if not jit:
|
142 |
+
model = build_model(state_dict or model.state_dict()).to(device)
|
143 |
+
if str(device) == "cpu":
|
144 |
+
model.float()
|
145 |
+
return model, _transform(model.visual.input_resolution)
|
146 |
+
|
147 |
+
# patch the device names
|
148 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
149 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
150 |
+
|
151 |
+
def patch_device(module):
|
152 |
+
try:
|
153 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
154 |
+
except RuntimeError:
|
155 |
+
graphs = []
|
156 |
+
|
157 |
+
if hasattr(module, "forward1"):
|
158 |
+
graphs.append(module.forward1.graph)
|
159 |
+
|
160 |
+
for graph in graphs:
|
161 |
+
for node in graph.findAllNodes("prim::Constant"):
|
162 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
163 |
+
node.copyAttributes(device_node)
|
164 |
+
|
165 |
+
model.apply(patch_device)
|
166 |
+
patch_device(model.encode_image)
|
167 |
+
patch_device(model.encode_text)
|
168 |
+
|
169 |
+
# patch dtype to float32 on CPU
|
170 |
+
if str(device) == "cpu":
|
171 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
172 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
173 |
+
float_node = float_input.node()
|
174 |
+
|
175 |
+
def patch_float(module):
|
176 |
+
try:
|
177 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
178 |
+
except RuntimeError:
|
179 |
+
graphs = []
|
180 |
+
|
181 |
+
if hasattr(module, "forward1"):
|
182 |
+
graphs.append(module.forward1.graph)
|
183 |
+
|
184 |
+
for graph in graphs:
|
185 |
+
for node in graph.findAllNodes("aten::to"):
|
186 |
+
inputs = list(node.inputs())
|
187 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
188 |
+
if inputs[i].node()["value"] == 5:
|
189 |
+
inputs[i].node().copyAttributes(float_node)
|
190 |
+
|
191 |
+
model.apply(patch_float)
|
192 |
+
patch_float(model.encode_image)
|
193 |
+
patch_float(model.encode_text)
|
194 |
+
|
195 |
+
model.float()
|
196 |
+
|
197 |
+
return model, _transform(model.input_resolution.item())
|
198 |
+
|
199 |
+
|
200 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
|
201 |
+
"""
|
202 |
+
Returns the tokenized representation of given input string(s)
|
203 |
+
|
204 |
+
Parameters
|
205 |
+
----------
|
206 |
+
texts : Union[str, List[str]]
|
207 |
+
An input string or a list of input strings to tokenize
|
208 |
+
|
209 |
+
context_length : int
|
210 |
+
The context length to use; all CLIP models use 77 as the context length
|
211 |
+
|
212 |
+
truncate: bool
|
213 |
+
Whether to truncate the text in case its encoding is longer than the context length
|
214 |
+
|
215 |
+
Returns
|
216 |
+
-------
|
217 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
218 |
+
"""
|
219 |
+
if isinstance(texts, str):
|
220 |
+
texts = [texts]
|
221 |
+
|
222 |
+
sot_token = _tokenizer.encoder["<|startoftext|>"]
|
223 |
+
eot_token = _tokenizer.encoder["<|endoftext|>"]
|
224 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
225 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
226 |
+
|
227 |
+
for i, tokens in enumerate(all_tokens):
|
228 |
+
if len(tokens) > context_length:
|
229 |
+
if truncate:
|
230 |
+
tokens = tokens[:context_length]
|
231 |
+
tokens[-1] = eot_token
|
232 |
+
else:
|
233 |
+
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
234 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
235 |
+
|
236 |
+
return result
|
fourm/utils/clip/model.py
ADDED
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Based on the CLIP code base
|
3 |
+
# https://github.com/openai/CLIP
|
4 |
+
# --------------------------------------------------------
|
5 |
+
|
6 |
+
from collections import OrderedDict
|
7 |
+
from typing import Tuple, Union
|
8 |
+
import math
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
|
15 |
+
class Bottleneck(nn.Module):
|
16 |
+
expansion = 4
|
17 |
+
|
18 |
+
def __init__(self, inplanes, planes, stride=1):
|
19 |
+
super().__init__()
|
20 |
+
|
21 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
22 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
23 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
24 |
+
|
25 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
26 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
27 |
+
|
28 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
29 |
+
|
30 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
31 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
32 |
+
|
33 |
+
self.relu = nn.ReLU(inplace=True)
|
34 |
+
self.downsample = None
|
35 |
+
self.stride = stride
|
36 |
+
|
37 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
38 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
39 |
+
self.downsample = nn.Sequential(OrderedDict([
|
40 |
+
("-1", nn.AvgPool2d(stride)),
|
41 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
42 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
43 |
+
]))
|
44 |
+
|
45 |
+
def forward(self, x: torch.Tensor):
|
46 |
+
identity = x
|
47 |
+
|
48 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
49 |
+
out = self.relu(self.bn2(self.conv2(out)))
|
50 |
+
out = self.avgpool(out)
|
51 |
+
out = self.bn3(self.conv3(out))
|
52 |
+
|
53 |
+
if self.downsample is not None:
|
54 |
+
identity = self.downsample(x)
|
55 |
+
|
56 |
+
out += identity
|
57 |
+
out = self.relu(out)
|
58 |
+
return out
|
59 |
+
|
60 |
+
|
61 |
+
class AttentionPool2d(nn.Module):
|
62 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
63 |
+
super().__init__()
|
64 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
65 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
66 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
67 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
68 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
69 |
+
self.num_heads = num_heads
|
70 |
+
|
71 |
+
def forward(self, x, return_all_tokens=False):
|
72 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
73 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
74 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
75 |
+
x, _ = F.multi_head_attention_forward(
|
76 |
+
query=x, key=x, value=x,
|
77 |
+
embed_dim_to_check=x.shape[-1],
|
78 |
+
num_heads=self.num_heads,
|
79 |
+
q_proj_weight=self.q_proj.weight,
|
80 |
+
k_proj_weight=self.k_proj.weight,
|
81 |
+
v_proj_weight=self.v_proj.weight,
|
82 |
+
in_proj_weight=None,
|
83 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
84 |
+
bias_k=None,
|
85 |
+
bias_v=None,
|
86 |
+
add_zero_attn=False,
|
87 |
+
dropout_p=0,
|
88 |
+
out_proj_weight=self.c_proj.weight,
|
89 |
+
out_proj_bias=self.c_proj.bias,
|
90 |
+
use_separate_proj_weight=True,
|
91 |
+
training=self.training,
|
92 |
+
need_weights=False
|
93 |
+
)
|
94 |
+
if return_all_tokens:
|
95 |
+
return x
|
96 |
+
else:
|
97 |
+
return x[0]
|
98 |
+
|
99 |
+
|
100 |
+
class ModifiedResNet(nn.Module):
|
101 |
+
"""
|
102 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
103 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
104 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
105 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
106 |
+
"""
|
107 |
+
|
108 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
109 |
+
super().__init__()
|
110 |
+
self.output_dim = output_dim
|
111 |
+
self.input_resolution = input_resolution
|
112 |
+
|
113 |
+
# the 3-layer stem
|
114 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
115 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
116 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
117 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
118 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
119 |
+
self.bn3 = nn.BatchNorm2d(width)
|
120 |
+
self.avgpool = nn.AvgPool2d(2)
|
121 |
+
self.relu = nn.ReLU(inplace=True)
|
122 |
+
|
123 |
+
# residual layers
|
124 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
125 |
+
self.layer1 = self._make_layer(width, layers[0])
|
126 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
127 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
128 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
129 |
+
|
130 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
131 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
132 |
+
|
133 |
+
def _make_layer(self, planes, blocks, stride=1):
|
134 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
135 |
+
|
136 |
+
self._inplanes = planes * Bottleneck.expansion
|
137 |
+
for _ in range(1, blocks):
|
138 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
139 |
+
|
140 |
+
return nn.Sequential(*layers)
|
141 |
+
|
142 |
+
def forward(self, x, return_side_out=False, return_all_tokens=False):
|
143 |
+
def stem(x):
|
144 |
+
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
|
145 |
+
x = self.relu(bn(conv(x)))
|
146 |
+
x = self.avgpool(x)
|
147 |
+
return x
|
148 |
+
out = []
|
149 |
+
x = x.type(self.conv1.weight.dtype)
|
150 |
+
x = stem(x)
|
151 |
+
x = self.layer1(x)
|
152 |
+
if return_side_out:
|
153 |
+
out.append(x)
|
154 |
+
x = self.layer2(x)
|
155 |
+
if return_side_out:
|
156 |
+
out.append(x)
|
157 |
+
x = self.layer3(x)
|
158 |
+
if return_side_out:
|
159 |
+
out.append(x)
|
160 |
+
x = self.layer4(x)
|
161 |
+
if return_side_out:
|
162 |
+
out.append(x)
|
163 |
+
x = self.attnpool(x, return_all_tokens)
|
164 |
+
out.append(x)
|
165 |
+
if len(out) == 1:
|
166 |
+
return x
|
167 |
+
else:
|
168 |
+
return out
|
169 |
+
|
170 |
+
|
171 |
+
class LayerNorm(nn.LayerNorm):
|
172 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
173 |
+
|
174 |
+
def forward(self, x: torch.Tensor):
|
175 |
+
orig_type = x.dtype
|
176 |
+
ret = super().forward(x.type(torch.float32))
|
177 |
+
return ret.type(orig_type)
|
178 |
+
|
179 |
+
|
180 |
+
class QuickGELU(nn.Module):
|
181 |
+
def forward(self, x: torch.Tensor):
|
182 |
+
return x * torch.sigmoid(1.702 * x)
|
183 |
+
|
184 |
+
|
185 |
+
class ResidualAttentionBlock(nn.Module):
|
186 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
187 |
+
super().__init__()
|
188 |
+
|
189 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
190 |
+
self.ln_1 = LayerNorm(d_model)
|
191 |
+
self.mlp = nn.Sequential(OrderedDict([
|
192 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
193 |
+
("gelu", QuickGELU()),
|
194 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
195 |
+
]))
|
196 |
+
self.ln_2 = LayerNorm(d_model)
|
197 |
+
self.attn_mask = attn_mask
|
198 |
+
|
199 |
+
def attention(self, x: torch.Tensor):
|
200 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
201 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
202 |
+
|
203 |
+
def forward(self, x: torch.Tensor):
|
204 |
+
x = x + self.attention(self.ln_1(x))
|
205 |
+
x = x + self.mlp(self.ln_2(x))
|
206 |
+
return x
|
207 |
+
|
208 |
+
|
209 |
+
class Transformer(nn.Module):
|
210 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
211 |
+
super().__init__()
|
212 |
+
self.width = width
|
213 |
+
self.layers = layers
|
214 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
215 |
+
|
216 |
+
def forward(self, x: torch.Tensor, return_intermediate_out: bool = False):
|
217 |
+
if return_intermediate_out:
|
218 |
+
output = []
|
219 |
+
for block in self.resblocks:
|
220 |
+
x = block(x)
|
221 |
+
output.append(x)
|
222 |
+
return output
|
223 |
+
|
224 |
+
return self.resblocks(x)
|
225 |
+
|
226 |
+
|
227 |
+
class VisionTransformer(nn.Module):
|
228 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
229 |
+
super().__init__()
|
230 |
+
self.input_resolution = input_resolution
|
231 |
+
self.patch_size = patch_size
|
232 |
+
self.output_dim = output_dim
|
233 |
+
self.width = width
|
234 |
+
self.heads = heads
|
235 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
236 |
+
|
237 |
+
scale = width ** -0.5
|
238 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
239 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
240 |
+
self.ln_pre = LayerNorm(width)
|
241 |
+
|
242 |
+
self.transformer = Transformer(width, layers, heads)
|
243 |
+
|
244 |
+
self.ln_post = LayerNorm(width)
|
245 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
246 |
+
|
247 |
+
def forward(self, x: torch.Tensor, return_all_tokens=False, return_all_final_tokens=False, return_final_tokens_no_cls=False, **kwargs):
|
248 |
+
|
249 |
+
B, nc, w, h = x.shape
|
250 |
+
|
251 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
252 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
253 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
254 |
+
|
255 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
256 |
+
|
257 |
+
if x.shape[1] != self.positional_embedding.shape[0]:
|
258 |
+
x = x + self.interpolate_pos_encoding(x, w, h).to(x.dtype)
|
259 |
+
else:
|
260 |
+
x = x + self.positional_embedding.to(x.dtype)
|
261 |
+
|
262 |
+
x = self.ln_pre(x)
|
263 |
+
|
264 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
265 |
+
x = self.transformer(x)
|
266 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
267 |
+
|
268 |
+
if return_all_tokens:
|
269 |
+
x = self.ln_post(x)
|
270 |
+
return x[:, 1:, :]
|
271 |
+
|
272 |
+
if return_all_final_tokens:
|
273 |
+
return self.ln_post(x) @ self.proj
|
274 |
+
|
275 |
+
if return_final_tokens_no_cls:
|
276 |
+
return self.ln_post(x)[:, 1:, :] @ self.proj
|
277 |
+
|
278 |
+
x = self.ln_post(x[:, 0, :])
|
279 |
+
|
280 |
+
if self.proj is not None:
|
281 |
+
x = x @ self.proj
|
282 |
+
|
283 |
+
return x
|
284 |
+
|
285 |
+
def interpolate_pos_encoding(self, x, w, h):
|
286 |
+
npatch = x.shape[1] - 1
|
287 |
+
N = self.positional_embedding.shape[0] - 1 # 256 for large
|
288 |
+
if npatch == N and w == h:
|
289 |
+
return self.positional_embedding
|
290 |
+
class_pos_embed = self.positional_embedding[[0]]
|
291 |
+
patch_pos_embed = self.positional_embedding[1:]
|
292 |
+
dim = x.shape[-1]
|
293 |
+
w0 = w // self.patch_size
|
294 |
+
h0 = h // self.patch_size
|
295 |
+
# we add a small number to avoid floating point error in the interpolation
|
296 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
297 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
298 |
+
patch_pos_embed = nn.functional.interpolate(
|
299 |
+
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
300 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
301 |
+
mode='bicubic',
|
302 |
+
)
|
303 |
+
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
|
304 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
305 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
306 |
+
|
307 |
+
|
308 |
+
class CLIP(nn.Module):
|
309 |
+
def __init__(self,
|
310 |
+
embed_dim: int, # 512
|
311 |
+
# vision
|
312 |
+
image_resolution: int, # 224
|
313 |
+
vision_layers: Union[Tuple[int, int, int, int], int], # 12
|
314 |
+
vision_width: int, # 768
|
315 |
+
vision_patch_size: int, # 16
|
316 |
+
# text
|
317 |
+
context_length: int, # 77
|
318 |
+
vocab_size: int, # 49408
|
319 |
+
transformer_width: int, # 512
|
320 |
+
transformer_heads: int, # 8
|
321 |
+
transformer_layers: int # 12
|
322 |
+
):
|
323 |
+
super().__init__()
|
324 |
+
self.context_length = context_length
|
325 |
+
|
326 |
+
if isinstance(vision_layers, (tuple, list)):
|
327 |
+
vision_heads = vision_width * 32 // 64
|
328 |
+
self.visual = ModifiedResNet(
|
329 |
+
layers=vision_layers,
|
330 |
+
output_dim=embed_dim,
|
331 |
+
heads=vision_heads,
|
332 |
+
input_resolution=image_resolution,
|
333 |
+
width=vision_width
|
334 |
+
)
|
335 |
+
else:
|
336 |
+
vision_heads = vision_width // 64
|
337 |
+
self.visual = VisionTransformer(
|
338 |
+
input_resolution=image_resolution,
|
339 |
+
patch_size=vision_patch_size,
|
340 |
+
width=vision_width,
|
341 |
+
layers=vision_layers,
|
342 |
+
heads=vision_heads,
|
343 |
+
output_dim=embed_dim
|
344 |
+
)
|
345 |
+
|
346 |
+
self.transformer = Transformer(
|
347 |
+
width=transformer_width,
|
348 |
+
layers=transformer_layers,
|
349 |
+
heads=transformer_heads,
|
350 |
+
attn_mask=self.build_attention_mask()
|
351 |
+
)
|
352 |
+
|
353 |
+
self.vocab_size = vocab_size
|
354 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
355 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
356 |
+
self.ln_final = LayerNorm(transformer_width)
|
357 |
+
|
358 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
359 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
360 |
+
|
361 |
+
self.initialize_parameters()
|
362 |
+
|
363 |
+
def initialize_parameters(self):
|
364 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
365 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
366 |
+
|
367 |
+
if isinstance(self.visual, ModifiedResNet):
|
368 |
+
if self.visual.attnpool is not None:
|
369 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
370 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
371 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
372 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
373 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
374 |
+
|
375 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
376 |
+
for name, param in resnet_block.named_parameters():
|
377 |
+
if name.endswith("bn3.weight"):
|
378 |
+
nn.init.zeros_(param)
|
379 |
+
|
380 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
381 |
+
attn_std = self.transformer.width ** -0.5
|
382 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
383 |
+
for block in self.transformer.resblocks:
|
384 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
385 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
386 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
387 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
388 |
+
|
389 |
+
if self.text_projection is not None:
|
390 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
391 |
+
|
392 |
+
def build_attention_mask(self):
|
393 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
394 |
+
# pytorch uses additive attention mask; fill with -inf
|
395 |
+
mask = torch.empty(self.context_length, self.context_length)
|
396 |
+
mask.fill_(float("-inf"))
|
397 |
+
mask.triu_(1) # zero out the lower diagonal
|
398 |
+
return mask
|
399 |
+
|
400 |
+
@property
|
401 |
+
def dtype(self):
|
402 |
+
return self.visual.conv1.weight.dtype
|
403 |
+
|
404 |
+
def encode_image(self, image, return_side_out=False, return_all_tokens=False, return_all_final_tokens=False, **kwargs):
|
405 |
+
return self.visual(image.type(self.dtype), return_all_tokens, return_all_final_tokens, **kwargs)
|
406 |
+
|
407 |
+
def encode_text(self, text, return_all_tokens=False, return_patch_tokens=False):
|
408 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
409 |
+
|
410 |
+
x = x + self.positional_embedding.type(self.dtype)
|
411 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
412 |
+
x = self.transformer(x)
|
413 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
414 |
+
x = self.ln_final(x).type(self.dtype)
|
415 |
+
|
416 |
+
if return_patch_tokens:
|
417 |
+
return x
|
418 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
419 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
420 |
+
if return_all_tokens:
|
421 |
+
x = x @ self.text_projection
|
422 |
+
else:
|
423 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
424 |
+
return x
|
425 |
+
|
426 |
+
def forward(self, image, text):
|
427 |
+
image_features = self.encode_image(image)
|
428 |
+
text_features = self.encode_text(text)
|
429 |
+
|
430 |
+
# normalized features
|
431 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
432 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
433 |
+
|
434 |
+
# cosine similarity as logits
|
435 |
+
logit_scale = self.logit_scale.exp()
|
436 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
437 |
+
logits_per_text = logits_per_image.t()
|
438 |
+
|
439 |
+
# shape = [global_batch_size, global_batch_size]
|
440 |
+
return logits_per_image, logits_per_text
|
441 |
+
|
442 |
+
|
443 |
+
def convert_weights(model: nn.Module):
|
444 |
+
"""Convert applicable model parameters to fp16"""
|
445 |
+
|
446 |
+
def _convert_weights_to_fp16(l):
|
447 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
448 |
+
l.weight.data = l.weight.data.half()
|
449 |
+
if l.bias is not None:
|
450 |
+
l.bias.data = l.bias.data.half()
|
451 |
+
|
452 |
+
if isinstance(l, nn.MultiheadAttention):
|
453 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
454 |
+
tensor = getattr(l, attr)
|
455 |
+
if tensor is not None:
|
456 |
+
tensor.data = tensor.data.half()
|
457 |
+
|
458 |
+
for name in ["text_projection", "proj"]:
|
459 |
+
if hasattr(l, name):
|
460 |
+
attr = getattr(l, name)
|
461 |
+
if attr is not None:
|
462 |
+
attr.data = attr.data.half()
|
463 |
+
|
464 |
+
model.apply(_convert_weights_to_fp16)
|
465 |
+
|
466 |
+
|
467 |
+
def build_model(state_dict: dict):
|
468 |
+
vit = "visual.proj" in state_dict
|
469 |
+
|
470 |
+
if vit:
|
471 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
472 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
473 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
474 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
475 |
+
image_resolution = vision_patch_size * grid_size
|
476 |
+
else:
|
477 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
478 |
+
vision_layers = tuple(counts)
|
479 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
480 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
481 |
+
vision_patch_size = None
|
482 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
483 |
+
image_resolution = output_width * 32
|
484 |
+
|
485 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
486 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
487 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
488 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
489 |
+
transformer_heads = transformer_width // 64
|
490 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
491 |
+
|
492 |
+
model = CLIP(
|
493 |
+
embed_dim,
|
494 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
495 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
496 |
+
)
|
497 |
+
|
498 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
499 |
+
if key in state_dict:
|
500 |
+
del state_dict[key]
|
501 |
+
|
502 |
+
convert_weights(model)
|
503 |
+
model.load_state_dict(state_dict)
|
504 |
+
return model.eval()
|
fourm/utils/clip/simple_tokenizer.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Based on the CLIP code base
|
3 |
+
# https://github.com/openai/CLIP
|
4 |
+
# --------------------------------------------------------
|
5 |
+
|
6 |
+
import gzip
|
7 |
+
import html
|
8 |
+
import os
|
9 |
+
from functools import lru_cache
|
10 |
+
|
11 |
+
import ftfy
|
12 |
+
import regex as re
|
13 |
+
|
14 |
+
|
15 |
+
@lru_cache()
|
16 |
+
def default_bpe():
|
17 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
18 |
+
|
19 |
+
|
20 |
+
@lru_cache()
|
21 |
+
def bytes_to_unicode():
|
22 |
+
"""
|
23 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
24 |
+
The reversible bpe codes work on unicode strings.
|
25 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
26 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
27 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
28 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
29 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
30 |
+
"""
|
31 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
32 |
+
cs = bs[:]
|
33 |
+
n = 0
|
34 |
+
for b in range(2**8):
|
35 |
+
if b not in bs:
|
36 |
+
bs.append(b)
|
37 |
+
cs.append(2**8+n)
|
38 |
+
n += 1
|
39 |
+
cs = [chr(n) for n in cs]
|
40 |
+
return dict(zip(bs, cs))
|
41 |
+
|
42 |
+
|
43 |
+
def get_pairs(word):
|
44 |
+
"""Return set of symbol pairs in a word.
|
45 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
46 |
+
"""
|
47 |
+
pairs = set()
|
48 |
+
prev_char = word[0]
|
49 |
+
for char in word[1:]:
|
50 |
+
pairs.add((prev_char, char))
|
51 |
+
prev_char = char
|
52 |
+
return pairs
|
53 |
+
|
54 |
+
|
55 |
+
def basic_clean(text):
|
56 |
+
text = ftfy.fix_text(text)
|
57 |
+
text = html.unescape(html.unescape(text))
|
58 |
+
return text.strip()
|
59 |
+
|
60 |
+
|
61 |
+
def whitespace_clean(text):
|
62 |
+
text = re.sub(r'\s+', ' ', text)
|
63 |
+
text = text.strip()
|
64 |
+
return text
|
65 |
+
|
66 |
+
|
67 |
+
class SimpleTokenizer(object):
|
68 |
+
def __init__(self, bpe_path: str = default_bpe()):
|
69 |
+
self.byte_encoder = bytes_to_unicode()
|
70 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
71 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
72 |
+
merges = merges[1:49152-256-2+1]
|
73 |
+
merges = [tuple(merge.split()) for merge in merges]
|
74 |
+
vocab = list(bytes_to_unicode().values())
|
75 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
76 |
+
for merge in merges:
|
77 |
+
vocab.append(''.join(merge))
|
78 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
79 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
80 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
81 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
82 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
83 |
+
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
84 |
+
|
85 |
+
def bpe(self, token):
|
86 |
+
if token in self.cache:
|
87 |
+
return self.cache[token]
|
88 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
89 |
+
pairs = get_pairs(word)
|
90 |
+
|
91 |
+
if not pairs:
|
92 |
+
return token+'</w>'
|
93 |
+
|
94 |
+
while True:
|
95 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
96 |
+
if bigram not in self.bpe_ranks:
|
97 |
+
break
|
98 |
+
first, second = bigram
|
99 |
+
new_word = []
|
100 |
+
i = 0
|
101 |
+
while i < len(word):
|
102 |
+
try:
|
103 |
+
j = word.index(first, i)
|
104 |
+
new_word.extend(word[i:j])
|
105 |
+
i = j
|
106 |
+
except:
|
107 |
+
new_word.extend(word[i:])
|
108 |
+
break
|
109 |
+
|
110 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
111 |
+
new_word.append(first+second)
|
112 |
+
i += 2
|
113 |
+
else:
|
114 |
+
new_word.append(word[i])
|
115 |
+
i += 1
|
116 |
+
new_word = tuple(new_word)
|
117 |
+
word = new_word
|
118 |
+
if len(word) == 1:
|
119 |
+
break
|
120 |
+
else:
|
121 |
+
pairs = get_pairs(word)
|
122 |
+
word = ' '.join(word)
|
123 |
+
self.cache[token] = word
|
124 |
+
return word
|
125 |
+
|
126 |
+
def encode(self, text):
|
127 |
+
bpe_tokens = []
|
128 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
129 |
+
for token in re.findall(self.pat, text):
|
130 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
131 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
132 |
+
return bpe_tokens
|
133 |
+
|
134 |
+
def decode(self, tokens):
|
135 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
136 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
137 |
+
return text
|
fourm/utils/data_constants.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
DEFAULT_CROP_PCT = 0.875
|
15 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
16 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
17 |
+
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
|
18 |
+
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
|
19 |
+
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
|
20 |
+
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
|
21 |
+
|
22 |
+
IMAGENET_SURFACE_NORMAL_MEAN = (0.501, 0.405, 0.137)
|
23 |
+
IMAGENET_SURFACE_NORMAL_STD = (0.114, 0.165, 0.081)
|
24 |
+
|
25 |
+
SEG_IGNORE_INDEX = 255
|
26 |
+
SEG_IGNORE_INDEX_V2 = 0
|
27 |
+
PAD_MASK_VALUE = 254
|
28 |
+
COCO_SEMSEG_NUM_CLASSES = 133 + 1 # One extra class for no-class
|
29 |
+
ADE20K_SEMSEG_NUM_CLASSES = 150 + 1 # One extra class for no-class
|
30 |
+
HYPERSIM_SEMSEG_NUM_CLASSES = 41
|
31 |
+
|
32 |
+
|
33 |
+
IMAGE_TASKS = {'rgb', 'depth', 'semseg', 'semseg_hypersim', 'semseg_coco', 'semseg_ade20k', 'normal'}
|
34 |
+
DETECTION_TASKS = {'det'} # 'det_coco', 'det_lvis'
|
35 |
+
TEXT_TASKS = {'caption'}
|
36 |
+
VISION_TASKS = IMAGE_TASKS | DETECTION_TASKS
|
37 |
+
SEQUENCE_TASKS = DETECTION_TASKS | TEXT_TASKS
|
38 |
+
|
39 |
+
NYU_MEAN = 2070.7764
|
40 |
+
NYU_STD = 777.5723
|
fourm/utils/dist.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# --------------------------------------------------------
|
15 |
+
# Based on DETR and MMCV code bases
|
16 |
+
# https://github.com/facebookresearch/detr
|
17 |
+
# https://github.com/open-mmlab/mmcv
|
18 |
+
# --------------------------------------------------------
|
19 |
+
|
20 |
+
import os
|
21 |
+
import pickle
|
22 |
+
import shutil
|
23 |
+
import sys
|
24 |
+
import tempfile
|
25 |
+
import datetime
|
26 |
+
|
27 |
+
import torch
|
28 |
+
import torch.distributed as dist
|
29 |
+
|
30 |
+
|
31 |
+
def setup_for_distributed(is_main):
|
32 |
+
"""
|
33 |
+
This function disables printing when not in main process
|
34 |
+
"""
|
35 |
+
import builtins as __builtin__
|
36 |
+
builtin_print = __builtin__.print
|
37 |
+
|
38 |
+
def print(*args, **kwargs):
|
39 |
+
force = kwargs.pop('force', False)
|
40 |
+
if is_main or force or kwargs.get('file', None) == sys.stderr:
|
41 |
+
builtin_print(*args, **kwargs)
|
42 |
+
|
43 |
+
__builtin__.print = print
|
44 |
+
|
45 |
+
|
46 |
+
def is_dist_avail_and_initialized():
|
47 |
+
if not dist.is_available():
|
48 |
+
return False
|
49 |
+
if not dist.is_initialized():
|
50 |
+
return False
|
51 |
+
return True
|
52 |
+
|
53 |
+
|
54 |
+
def get_world_size():
|
55 |
+
if not is_dist_avail_and_initialized():
|
56 |
+
return 1
|
57 |
+
return dist.get_world_size()
|
58 |
+
|
59 |
+
|
60 |
+
def get_rank():
|
61 |
+
if not is_dist_avail_and_initialized():
|
62 |
+
return 0
|
63 |
+
return dist.get_rank()
|
64 |
+
|
65 |
+
|
66 |
+
def is_main_process():
|
67 |
+
return get_rank() == 0
|
68 |
+
|
69 |
+
|
70 |
+
def save_on_main(*args, **kwargs):
|
71 |
+
if is_main_process():
|
72 |
+
torch.save(*args, **kwargs)
|
73 |
+
|
74 |
+
def save_on_all(*args, **kwargs):
|
75 |
+
torch.save(*args, **kwargs)
|
76 |
+
|
77 |
+
|
78 |
+
def init_distributed_mode(args):
|
79 |
+
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
80 |
+
args.rank = int(os.environ["RANK"])
|
81 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
82 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
83 |
+
else:
|
84 |
+
print('Not using distributed mode')
|
85 |
+
args.distributed = False
|
86 |
+
return
|
87 |
+
|
88 |
+
args.distributed = True
|
89 |
+
|
90 |
+
torch.cuda.set_device(args.gpu)
|
91 |
+
args.dist_backend = 'nccl'
|
92 |
+
print('| distributed init (rank {}): {}, gpu {}'.format(
|
93 |
+
args.rank, args.dist_url, args.gpu), flush=True)
|
94 |
+
# Set timeout to 1h20 in case some long download of dataset has to happen
|
95 |
+
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
96 |
+
world_size=args.world_size, rank=args.rank,
|
97 |
+
timeout=datetime.timedelta(4800))
|
98 |
+
torch.distributed.barrier()
|
99 |
+
if ("print_all" not in args) or (not args.print_all):
|
100 |
+
setup_for_distributed(args.rank == 0)
|
fourm/utils/fsdp_utils.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
from pathlib import Path
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
from .dist import save_on_main, is_main_process
|
20 |
+
from .s3_utils import save_on_s3
|
21 |
+
|
22 |
+
|
23 |
+
from torch.distributed.fsdp import (
|
24 |
+
FullyShardedDataParallel as FSDP,
|
25 |
+
FullStateDictConfig,
|
26 |
+
StateDictType,
|
27 |
+
)
|
28 |
+
|
29 |
+
from torch.distributed.fsdp.api import FullOptimStateDictConfig
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
def save_model_fsdp(args, epoch, model, optimizer, model_ema=None, ckpt_name=None, use_s3=False):
|
34 |
+
output_dir = Path(args.output_dir)
|
35 |
+
epoch_name = str(epoch)
|
36 |
+
ckpt_name = ckpt_name or epoch_name
|
37 |
+
|
38 |
+
|
39 |
+
with FSDP.state_dict_type(model,
|
40 |
+
StateDictType.FULL_STATE_DICT,
|
41 |
+
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
|
42 |
+
FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
|
43 |
+
):
|
44 |
+
|
45 |
+
model_state_dict = model.state_dict()
|
46 |
+
if optimizer is not None:
|
47 |
+
optimizer_state_dict = FSDP.optim_state_dict(model, optimizer)
|
48 |
+
else:
|
49 |
+
optimizer_state_dict = None
|
50 |
+
|
51 |
+
# Only create the save_dict on the main process, not needed or recommended to do so on all ranks
|
52 |
+
# This make save_on_main() redundant
|
53 |
+
if is_main_process():
|
54 |
+
checkpoint_path = os.path.join(output_dir, f'checkpoint-{ckpt_name}.pth')
|
55 |
+
|
56 |
+
to_save = {
|
57 |
+
'model': model_state_dict,
|
58 |
+
'epoch': epoch,
|
59 |
+
'args': args,
|
60 |
+
}
|
61 |
+
|
62 |
+
if optimizer is not None:
|
63 |
+
to_save['optimizer'] = optimizer_state_dict
|
64 |
+
|
65 |
+
if model_ema is not None:
|
66 |
+
print("Model EMA is currently not supported for FSDP")
|
67 |
+
# to_save['model_ema'] = get_state_dict(model_ema)
|
68 |
+
|
69 |
+
save_on_main(to_save, checkpoint_path)
|
70 |
+
|
71 |
+
if use_s3:
|
72 |
+
s3_path = os.path.join(args.s3_save_dir, f'checkpoint-{ckpt_name}.pth')
|
73 |
+
save_on_s3(checkpoint_path, s3_path, args.s3_endpoint)
|
74 |
+
|
75 |
+
|
76 |
+
def auto_load_model_fsdp(args, model, optimizer, model_ema=None):
|
77 |
+
output_dir = Path(args.output_dir)
|
78 |
+
if args.auto_resume and len(args.resume) == 0:
|
79 |
+
import glob
|
80 |
+
all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
|
81 |
+
latest_ckpt = -1
|
82 |
+
for ckpt in all_checkpoints:
|
83 |
+
t = ckpt.split('-')[-1].split('.')[0]
|
84 |
+
if t.isdigit():
|
85 |
+
latest_ckpt = max(int(t), latest_ckpt)
|
86 |
+
if latest_ckpt >= 0:
|
87 |
+
args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
|
88 |
+
print("Auto resume checkpoint: %s" % args.resume)
|
89 |
+
|
90 |
+
if args.resume:
|
91 |
+
if args.resume.startswith('https'):
|
92 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
93 |
+
args.resume, map_location='cpu')
|
94 |
+
else:
|
95 |
+
checkpoint = torch.load(args.resume, map_location='cpu')
|
96 |
+
|
97 |
+
with FSDP.state_dict_type(
|
98 |
+
model,
|
99 |
+
StateDictType.FULL_STATE_DICT,
|
100 |
+
FullStateDictConfig(rank0_only=False),
|
101 |
+
FullOptimStateDictConfig(rank0_only=False),
|
102 |
+
):
|
103 |
+
|
104 |
+
model.load_state_dict(checkpoint['model'])
|
105 |
+
print("Resume checkpoint %s" % args.resume)
|
106 |
+
|
107 |
+
|
108 |
+
if 'optimizer' in checkpoint and 'epoch' in checkpoint:
|
109 |
+
optimizer_state_dict = FSDP.optim_state_dict_to_load(checkpoint['optimizer'], model, optimizer)
|
110 |
+
optimizer.load_state_dict(optimizer_state_dict)
|
111 |
+
args.start_epoch = checkpoint['epoch'] + 1
|
112 |
+
|
113 |
+
print("With optim & sched!")
|
114 |
+
|
115 |
+
if hasattr(args, 'model_ema') and args.model_ema:
|
116 |
+
print("Model EMA is currently not supported for FSDP")
|
fourm/utils/generation.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import numpy as np
|
15 |
+
import math
|
16 |
+
|
17 |
+
|
18 |
+
def sample_to_batch(mod_dict, device, domains):
|
19 |
+
mod_dict = {
|
20 |
+
modality: {k: v.unsqueeze(0).to(device, non_blocking=True) for k, v in d.items()}
|
21 |
+
for modality, d in mod_dict.items() if modality in domains
|
22 |
+
}
|
23 |
+
|
24 |
+
return mod_dict
|
25 |
+
|
26 |
+
|
27 |
+
def unbatch(tensor):
|
28 |
+
return tensor.detach().squeeze(0).cpu()
|
29 |
+
|
30 |
+
|
31 |
+
def batch_to_sample(mod_dict, domains):
|
32 |
+
mod_dict = {
|
33 |
+
modality: {k: unbatch(v) for k, v in d.items()}
|
34 |
+
for modality, d in mod_dict.items() if modality in domains
|
35 |
+
}
|
36 |
+
|
37 |
+
return mod_dict
|
38 |
+
|
39 |
+
|
40 |
+
def batch_to_device(mod_dict, device, domains):
|
41 |
+
mod_dict = {
|
42 |
+
modality: {k: v.to(device, non_blocking=True) for k, v in d.items()}
|
43 |
+
for modality, d in mod_dict.items() if modality in domains
|
44 |
+
}
|
45 |
+
|
46 |
+
return mod_dict
|
47 |
+
|
48 |
+
|
49 |
+
def cosine_schedule(num_steps, total_tokens):
|
50 |
+
iters = np.arange(num_steps)
|
51 |
+
base_value = 1
|
52 |
+
final_value = 0
|
53 |
+
schedule = np.array(
|
54 |
+
[final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
|
55 |
+
schedule_tokens = [round(total_tokens * i) for i in (schedule[:-1] - schedule[1:])]
|
56 |
+
schedule_tokens.append(total_tokens - sum(schedule_tokens))
|
57 |
+
return np.array(schedule_tokens)
|
58 |
+
|
59 |
+
|
60 |
+
def linear_schedule(num_steps, total_tokens):
|
61 |
+
schedule = np.linspace(0, total_tokens, num_steps + 1, dtype=int)
|
62 |
+
schedule_tokens = np.diff(schedule)[::-1]
|
63 |
+
schedule_tokens.sort() # Sorts the array in ascending order.
|
64 |
+
schedule_tokens = schedule_tokens[::-1] # Reverses the array to descending order.
|
65 |
+
return np.trim_zeros(schedule_tokens, 'b') # Trims trailing zeros.
|
66 |
+
|
67 |
+
|
68 |
+
def continue_schedule(schedule, num_current_tokens):
|
69 |
+
schedule_cumsum = np.cumsum(schedule)
|
70 |
+
keep_mask = schedule_cumsum > num_current_tokens
|
71 |
+
diff = schedule_cumsum[keep_mask][0] - num_current_tokens
|
72 |
+
new_schedule = schedule[keep_mask]
|
73 |
+
new_schedule[0] = diff
|
74 |
+
return new_schedule
|
75 |
+
|
76 |
+
|
77 |
+
def decreasing_temp_schedule(max, min, token_schedule):
|
78 |
+
schedule_cumsum = np.cumsum(token_schedule) / np.sum(token_schedule)
|
79 |
+
temp_schedule = np.array([min + (max - min) * (1 - s) for s in schedule_cumsum])
|
80 |
+
return temp_schedule
|
81 |
+
|
82 |
+
|
83 |
+
def onex_temp_schedule(max_t, min_t, token_schedule, power=0.5, min_linspace=1, max_linspace=100):
|
84 |
+
"""Abitrary temperature schedule for one over x"""
|
85 |
+
x = np.linspace(min_linspace, max_linspace, num=sum(token_schedule))
|
86 |
+
y = 1/(x**power)
|
87 |
+
y = y - min(y)
|
88 |
+
y = y / max(y)
|
89 |
+
unscaled_schedule = y
|
90 |
+
schedule_cumsum = np.cumsum(token_schedule) / np.sum(token_schedule)
|
91 |
+
unscaled_schedule = [(1 - cs) * us for us, cs in zip(unscaled_schedule, schedule_cumsum)]
|
92 |
+
|
93 |
+
temp_schedule = np.array([min_t + (max_t - min_t) * s for s in unscaled_schedule]).clip(min=1e-9)
|
94 |
+
return temp_schedule
|
95 |
+
|
96 |
+
|
97 |
+
def linear_temp_schedule(temp, token_schedule):
|
98 |
+
""" Temperature that decays the temperature inversely proportional to the token schedule. """
|
99 |
+
return np.concatenate([np.array([temp * 1.0]), (temp * (token_schedule.sum() - token_schedule.cumsum()) / token_schedule.sum())[:-1]]).clip(min=1e-9)
|
fourm/utils/generation_datasets/PartiPrompts.tsv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
fourm/utils/generation_datasets/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .parti_prompts_dataset import *
|
2 |
+
from .empty_dataset import *
|
3 |
+
from .image_caption_dataset import *
|
fourm/utils/generation_datasets/empty_dataset.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
from torch.utils.data import Dataset
|
16 |
+
|
17 |
+
class EmptyDataset(Dataset):
|
18 |
+
"""Empty dataset"""
|
19 |
+
|
20 |
+
def __init__(self, dataset_size: int):
|
21 |
+
self.dataset_size = dataset_size
|
22 |
+
|
23 |
+
def __getitem__(self, index):
|
24 |
+
return {}
|
25 |
+
|
26 |
+
def __len__(self):
|
27 |
+
return self.dataset_size
|
fourm/utils/generation_datasets/image_caption_dataset.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
from torch.utils.data import Dataset
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
|
17 |
+
|
18 |
+
from fourm.data.multimodal_dataset_folder import make_dataset, UNIFIED_EXTENSIONS
|
19 |
+
from fourm.data.modality_transforms import get_transform_key, RGBTransform, CaptionTransform, UnifiedDataTransform
|
20 |
+
|
21 |
+
|
22 |
+
class ImageCaptionDataset(Dataset):
|
23 |
+
"""
|
24 |
+
Similar to MultiModalDatasetFolder, but specialized for image-caption datasets.
|
25 |
+
"""
|
26 |
+
def __init__(self,
|
27 |
+
root: str,
|
28 |
+
augmenter: Optional[Callable] = None,
|
29 |
+
modality_paths: Dict[str, str] = None,
|
30 |
+
is_valid_file: Optional[Callable[[str], bool]] = None,
|
31 |
+
cache=False):
|
32 |
+
self.root = root
|
33 |
+
self.modality_paths = modality_paths or {}
|
34 |
+
|
35 |
+
self.modality_transforms = {
|
36 |
+
'rgb': RGBTransform(imagenet_default_mean_and_std=False),
|
37 |
+
'caption': CaptionTransform()
|
38 |
+
}
|
39 |
+
|
40 |
+
self.transform = UnifiedDataTransform(transforms_dict=self.modality_transforms, image_augmenter=augmenter)
|
41 |
+
|
42 |
+
classes, class_to_idx = self._find_classes(os.path.join(self.root, self.modality_paths.get('caption', 'caption')))
|
43 |
+
extensions = UNIFIED_EXTENSIONS if is_valid_file is None else None
|
44 |
+
|
45 |
+
samples = {
|
46 |
+
mod: make_dataset(
|
47 |
+
os.path.join(self.root, self.modality_paths.get(mod, mod)),
|
48 |
+
class_to_idx,
|
49 |
+
extensions,
|
50 |
+
is_valid_file,
|
51 |
+
cache_path=os.path.join(self.root, 'dataloader_cache', f'{self.modality_paths.get(mod, mod)}.pkl') if cache else None)
|
52 |
+
for mod in ['caption', 'rgb']
|
53 |
+
}
|
54 |
+
|
55 |
+
for mod, mod_samples in samples.items():
|
56 |
+
if len(mod_samples) == 0:
|
57 |
+
msg = "Found 0 logs in subfolders of: {}\n".format(os.path.join(self.root, self.modality_paths.get(mod, mod)))
|
58 |
+
if extensions is not None:
|
59 |
+
msg += "Supported extensions are: {}".format(",".join(extensions))
|
60 |
+
raise RuntimeError(msg)
|
61 |
+
|
62 |
+
self.extensions = extensions
|
63 |
+
self.classes = classes
|
64 |
+
self.class_to_idx = class_to_idx
|
65 |
+
self.samples = samples
|
66 |
+
|
67 |
+
def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
|
68 |
+
"""
|
69 |
+
Finds the class folders in a dataset.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
dir (string): Root directory path.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
|
76 |
+
|
77 |
+
Ensures:
|
78 |
+
No class is a subdirectory of another.
|
79 |
+
"""
|
80 |
+
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
|
81 |
+
classes.sort()
|
82 |
+
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
|
83 |
+
return classes, class_to_idx
|
84 |
+
|
85 |
+
def __getitem__(self, index):
|
86 |
+
|
87 |
+
sample_dict = {}
|
88 |
+
for mod in ['caption', 'rgb']:
|
89 |
+
path, _ = self.samples[mod][index]
|
90 |
+
sample = self.modality_transforms[get_transform_key(mod)].load(path)
|
91 |
+
sample_dict[mod] = sample
|
92 |
+
|
93 |
+
if self.transform is not None:
|
94 |
+
sample_dict = self.transform(sample_dict)
|
95 |
+
|
96 |
+
return sample_dict
|
97 |
+
|
98 |
+
def __len__(self) -> int:
|
99 |
+
return len(list(self.samples.values())[0])
|
fourm/utils/generation_datasets/parti_prompts_dataset.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 EPFL and Apple Inc.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import pandas as pd
|
16 |
+
import numpy as np
|
17 |
+
from torch.utils.data import Dataset
|
18 |
+
|
19 |
+
|
20 |
+
class PartiPromptsDataset(Dataset):
|
21 |
+
"""
|
22 |
+
Parti Prompts caption dataset.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
text_tokenizer (tokenizers.Tokenizer): The tokenizer to use for encoding the captions.
|
26 |
+
max_length (int): The maximum sequence length of the captions.
|
27 |
+
parti_prompts_csv (str): The path to the Parti Prompts dataset.
|
28 |
+
"""
|
29 |
+
def __init__(self, text_tokenizer, max_length=128, parti_prompts_csv='fourm/utils/generation_datasets/PartiPrompts.tsv', parti_prompts_t5_embs=None, llm_embedder=None):
|
30 |
+
self.text_tokenizer = text_tokenizer
|
31 |
+
self.max_length = max_length
|
32 |
+
self.parti_prompts = pd.read_csv(parti_prompts_csv, sep='\t')
|
33 |
+
|
34 |
+
self.pad_id = text_tokenizer.token_to_id("[PAD]")
|
35 |
+
self.eos_id = text_tokenizer.token_to_id("[EOS]")
|
36 |
+
if parti_prompts_t5_embs is not None:
|
37 |
+
# T5 Embeddings are saved as a numpy array, so we need to load it
|
38 |
+
self.t5_embs = np.load(parti_prompts_t5_embs)['emb']
|
39 |
+
self.t5_masks = np.load(parti_prompts_t5_embs)['mask_valid']
|
40 |
+
self.llm_embedder = None
|
41 |
+
elif llm_embedder is not None:
|
42 |
+
self.t5_embs = None
|
43 |
+
self.llm_embedder = llm_embedder
|
44 |
+
else:
|
45 |
+
self.t5_embs = None
|
46 |
+
self.llm_embedder = None
|
47 |
+
|
48 |
+
def __getitem__(self, index):
|
49 |
+
text = self.parti_prompts.Prompt[index]
|
50 |
+
seq_ids = self.text_tokenizer.encode(text).ids + [self.eos_id]
|
51 |
+
|
52 |
+
tensor = torch.ones(self.max_length, dtype=torch.int) * self.pad_id
|
53 |
+
tensor[:len(seq_ids)] = torch.tensor(seq_ids, dtype=torch.int)
|
54 |
+
|
55 |
+
out = {}
|
56 |
+
out['caption'] = {'tensor': tensor}
|
57 |
+
|
58 |
+
if self.t5_embs is not None:
|
59 |
+
|
60 |
+
t5_emb = torch.tensor(self.t5_embs[index], dtype=torch.float32)
|
61 |
+
t5_emb = pad_or_truncate(t5_emb, self.max_length)
|
62 |
+
|
63 |
+
t5_mask = torch.tensor(self.t5_masks[index], dtype=torch.bool)
|
64 |
+
t5_mask = pad_or_truncate(t5_mask, self.max_length)
|
65 |
+
|
66 |
+
ascii_tensor = text_to_tensor(text, max_length=self.max_length * 10) # Save ASCII as tensor
|
67 |
+
|
68 |
+
out['t5_caption'] = {
|
69 |
+
'tensor': t5_emb,
|
70 |
+
'mask_valid': t5_mask,
|
71 |
+
'ascii_tensor': ascii_tensor,
|
72 |
+
}
|
73 |
+
elif self.llm_embedder is not None:
|
74 |
+
t5_emb, _, t5_mask = self.llm_embedder.get_text_embeddings([text])
|
75 |
+
t5_emb = pad_or_truncate(t5_emb.squeeze(0), self.max_length)
|
76 |
+
t5_mask = pad_or_truncate(t5_mask.bool().squeeze(0), self.max_length)
|
77 |
+
ascii_tensor = text_to_tensor(text, max_length=self.max_length * 10) # Save ASCII as tensor
|
78 |
+
|
79 |
+
out['t5_caption'] = {
|
80 |
+
'tensor': t5_emb,
|
81 |
+
'mask_valid': t5_mask,
|
82 |
+
'ascii_tensor': ascii_tensor,
|
83 |
+
}
|
84 |
+
|
85 |
+
return out
|
86 |
+
|
87 |
+
def __len__(self):
|
88 |
+
return len(self.parti_prompts)
|
89 |
+
|
90 |
+
|
91 |
+
def pad_or_truncate(tensor, fixed_length, padding_value=0):
|
92 |
+
current_length = tensor.shape[0]
|
93 |
+
|
94 |
+
if current_length < fixed_length:
|
95 |
+
# Calculate padding sizes for all dimensions, but only pad along dim=0
|
96 |
+
padding_sizes = [0] * 2 * len(tensor.shape)
|
97 |
+
padding_sizes[1] = fixed_length - current_length
|
98 |
+
return torch.nn.functional.pad(tensor, padding_sizes, 'constant', padding_value)
|
99 |
+
else:
|
100 |
+
return tensor[:fixed_length]
|
101 |
+
|
102 |
+
def text_to_tensor(text, max_length=None):
|
103 |
+
"""Converts plaintext to a tensor with optional padding."""
|
104 |
+
ascii_values = [ord(c) for c in text]
|
105 |
+
if max_length:
|
106 |
+
while len(ascii_values) < max_length:
|
107 |
+
ascii_values.append(0) # Using 0 as the padding value
|
108 |
+
return torch.tensor(ascii_values, dtype=torch.int)
|
109 |
+
|
110 |
+
|
111 |
+
def tensor_to_text(tensor):
|
112 |
+
"""Converts tensor back to plaintext. Assumes padding with zeros."""
|
113 |
+
ascii_values = tensor.tolist()
|
114 |
+
return ''.join(chr(val) for val in ascii_values if val != 0)
|
fourm/utils/hmr2_utils/hmr2/__init__.py
ADDED
File without changes
|
fourm/utils/hmr2_utils/hmr2/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .smpl_wrapper import SMPL
|
2 |
+
from .hmr2 import HMR2
|
fourm/utils/hmr2_utils/hmr2/models/backbones/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Based on the ViTPose and 4DHumans code bases
|
3 |
+
# https://github.com/ViTAE-Transformer/ViTPose/
|
4 |
+
# https://github.com/shubham-goel/4D-Humans
|
5 |
+
# --------------------------------------------------------
|
6 |
+
|
7 |
+
from .vit import vit
|
8 |
+
|
9 |
+
def create_backbone(cfg):
|
10 |
+
if cfg.MODEL.BACKBONE.TYPE == 'vit':
|
11 |
+
return vit(cfg)
|
12 |
+
else:
|
13 |
+
raise NotImplementedError('Backbone type is not implemented')
|
fourm/utils/hmr2_utils/hmr2/models/backbones/vit.py
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Based on the ViTPose and 4DHumans code bases
|
3 |
+
# https://github.com/ViTAE-Transformer/ViTPose/
|
4 |
+
# https://github.com/shubham-goel/4D-Humans
|
5 |
+
# --------------------------------------------------------
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from functools import partial
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import torch.utils.checkpoint as checkpoint
|
14 |
+
|
15 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
16 |
+
|
17 |
+
def vit(cfg):
|
18 |
+
return ViT(
|
19 |
+
img_size=(256, 192),
|
20 |
+
patch_size=16,
|
21 |
+
embed_dim=1280,
|
22 |
+
depth=32,
|
23 |
+
num_heads=16,
|
24 |
+
ratio=1,
|
25 |
+
use_checkpoint=False,
|
26 |
+
mlp_ratio=4,
|
27 |
+
qkv_bias=True,
|
28 |
+
drop_path_rate=0.55,
|
29 |
+
)
|
30 |
+
|
31 |
+
def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
|
32 |
+
"""
|
33 |
+
Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
|
34 |
+
dimension for the original embeddings.
|
35 |
+
Args:
|
36 |
+
abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
|
37 |
+
has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
|
38 |
+
hw (Tuple): size of input image tokens.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
Absolute positional embeddings after processing with shape (1, H, W, C)
|
42 |
+
"""
|
43 |
+
cls_token = None
|
44 |
+
B, L, C = abs_pos.shape
|
45 |
+
if has_cls_token:
|
46 |
+
cls_token = abs_pos[:, 0:1]
|
47 |
+
abs_pos = abs_pos[:, 1:]
|
48 |
+
|
49 |
+
if ori_h != h or ori_w != w:
|
50 |
+
new_abs_pos = F.interpolate(
|
51 |
+
abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
|
52 |
+
size=(h, w),
|
53 |
+
mode="bicubic",
|
54 |
+
align_corners=False,
|
55 |
+
).permute(0, 2, 3, 1).reshape(B, -1, C)
|
56 |
+
|
57 |
+
else:
|
58 |
+
new_abs_pos = abs_pos
|
59 |
+
|
60 |
+
if cls_token is not None:
|
61 |
+
new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
|
62 |
+
return new_abs_pos
|
63 |
+
|
64 |
+
class DropPath(nn.Module):
|
65 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
66 |
+
"""
|
67 |
+
def __init__(self, drop_prob=None):
|
68 |
+
super(DropPath, self).__init__()
|
69 |
+
self.drop_prob = drop_prob
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
return drop_path(x, self.drop_prob, self.training)
|
73 |
+
|
74 |
+
def extra_repr(self):
|
75 |
+
return 'p={}'.format(self.drop_prob)
|
76 |
+
|
77 |
+
class Mlp(nn.Module):
|
78 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
79 |
+
super().__init__()
|
80 |
+
out_features = out_features or in_features
|
81 |
+
hidden_features = hidden_features or in_features
|
82 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
83 |
+
self.act = act_layer()
|
84 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
85 |
+
self.drop = nn.Dropout(drop)
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
x = self.fc1(x)
|
89 |
+
x = self.act(x)
|
90 |
+
x = self.fc2(x)
|
91 |
+
x = self.drop(x)
|
92 |
+
return x
|
93 |
+
|
94 |
+
class Attention(nn.Module):
|
95 |
+
def __init__(
|
96 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
97 |
+
proj_drop=0., attn_head_dim=None,):
|
98 |
+
super().__init__()
|
99 |
+
self.num_heads = num_heads
|
100 |
+
head_dim = dim // num_heads
|
101 |
+
self.dim = dim
|
102 |
+
|
103 |
+
if attn_head_dim is not None:
|
104 |
+
head_dim = attn_head_dim
|
105 |
+
all_head_dim = head_dim * self.num_heads
|
106 |
+
|
107 |
+
self.scale = qk_scale or head_dim ** -0.5
|
108 |
+
|
109 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
|
110 |
+
|
111 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
112 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
113 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
114 |
+
|
115 |
+
def forward(self, x):
|
116 |
+
B, N, C = x.shape
|
117 |
+
qkv = self.qkv(x)
|
118 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
119 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
120 |
+
|
121 |
+
q = q * self.scale
|
122 |
+
attn = (q @ k.transpose(-2, -1))
|
123 |
+
|
124 |
+
attn = attn.softmax(dim=-1)
|
125 |
+
attn = self.attn_drop(attn)
|
126 |
+
|
127 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
128 |
+
x = self.proj(x)
|
129 |
+
x = self.proj_drop(x)
|
130 |
+
|
131 |
+
return x
|
132 |
+
|
133 |
+
class Block(nn.Module):
|
134 |
+
|
135 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
|
136 |
+
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
|
137 |
+
norm_layer=nn.LayerNorm, attn_head_dim=None
|
138 |
+
):
|
139 |
+
super().__init__()
|
140 |
+
|
141 |
+
self.norm1 = norm_layer(dim)
|
142 |
+
self.attn = Attention(
|
143 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
144 |
+
attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
|
145 |
+
)
|
146 |
+
|
147 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
148 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
149 |
+
self.norm2 = norm_layer(dim)
|
150 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
151 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
152 |
+
|
153 |
+
def forward(self, x):
|
154 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
155 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
156 |
+
return x
|
157 |
+
|
158 |
+
|
159 |
+
class PatchEmbed(nn.Module):
|
160 |
+
""" Image to Patch Embedding
|
161 |
+
"""
|
162 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
|
163 |
+
super().__init__()
|
164 |
+
img_size = to_2tuple(img_size)
|
165 |
+
patch_size = to_2tuple(patch_size)
|
166 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
|
167 |
+
self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
|
168 |
+
self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
|
169 |
+
self.img_size = img_size
|
170 |
+
self.patch_size = patch_size
|
171 |
+
self.num_patches = num_patches
|
172 |
+
|
173 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
|
174 |
+
|
175 |
+
def forward(self, x, **kwargs):
|
176 |
+
B, C, H, W = x.shape
|
177 |
+
x = self.proj(x)
|
178 |
+
Hp, Wp = x.shape[2], x.shape[3]
|
179 |
+
|
180 |
+
x = x.flatten(2).transpose(1, 2)
|
181 |
+
return x, (Hp, Wp)
|
182 |
+
|
183 |
+
|
184 |
+
class HybridEmbed(nn.Module):
|
185 |
+
""" CNN Feature Map Embedding
|
186 |
+
Extract feature map from CNN, flatten, project to embedding dim.
|
187 |
+
"""
|
188 |
+
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
|
189 |
+
super().__init__()
|
190 |
+
assert isinstance(backbone, nn.Module)
|
191 |
+
img_size = to_2tuple(img_size)
|
192 |
+
self.img_size = img_size
|
193 |
+
self.backbone = backbone
|
194 |
+
if feature_size is None:
|
195 |
+
with torch.no_grad():
|
196 |
+
training = backbone.training
|
197 |
+
if training:
|
198 |
+
backbone.eval()
|
199 |
+
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
|
200 |
+
feature_size = o.shape[-2:]
|
201 |
+
feature_dim = o.shape[1]
|
202 |
+
backbone.train(training)
|
203 |
+
else:
|
204 |
+
feature_size = to_2tuple(feature_size)
|
205 |
+
feature_dim = self.backbone.feature_info.channels()[-1]
|
206 |
+
self.num_patches = feature_size[0] * feature_size[1]
|
207 |
+
self.proj = nn.Linear(feature_dim, embed_dim)
|
208 |
+
|
209 |
+
def forward(self, x):
|
210 |
+
x = self.backbone(x)[-1]
|
211 |
+
x = x.flatten(2).transpose(1, 2)
|
212 |
+
x = self.proj(x)
|
213 |
+
return x
|
214 |
+
|
215 |
+
|
216 |
+
class ViT(nn.Module):
|
217 |
+
|
218 |
+
def __init__(self,
|
219 |
+
img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
|
220 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
221 |
+
drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
|
222 |
+
frozen_stages=-1, ratio=1, last_norm=True,
|
223 |
+
patch_padding='pad', freeze_attn=False, freeze_ffn=False,
|
224 |
+
):
|
225 |
+
# Protect mutable default arguments
|
226 |
+
super(ViT, self).__init__()
|
227 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
228 |
+
self.num_classes = num_classes
|
229 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
230 |
+
self.frozen_stages = frozen_stages
|
231 |
+
self.use_checkpoint = use_checkpoint
|
232 |
+
self.patch_padding = patch_padding
|
233 |
+
self.freeze_attn = freeze_attn
|
234 |
+
self.freeze_ffn = freeze_ffn
|
235 |
+
self.depth = depth
|
236 |
+
|
237 |
+
if hybrid_backbone is not None:
|
238 |
+
self.patch_embed = HybridEmbed(
|
239 |
+
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
|
240 |
+
else:
|
241 |
+
self.patch_embed = PatchEmbed(
|
242 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
|
243 |
+
num_patches = self.patch_embed.num_patches
|
244 |
+
|
245 |
+
# since the pretraining model has class token
|
246 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
247 |
+
|
248 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
249 |
+
|
250 |
+
self.blocks = nn.ModuleList([
|
251 |
+
Block(
|
252 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
253 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
254 |
+
)
|
255 |
+
for i in range(depth)])
|
256 |
+
|
257 |
+
self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
|
258 |
+
|
259 |
+
if self.pos_embed is not None:
|
260 |
+
trunc_normal_(self.pos_embed, std=.02)
|
261 |
+
|
262 |
+
self._freeze_stages()
|
263 |
+
|
264 |
+
def _freeze_stages(self):
|
265 |
+
"""Freeze parameters."""
|
266 |
+
if self.frozen_stages >= 0:
|
267 |
+
self.patch_embed.eval()
|
268 |
+
for param in self.patch_embed.parameters():
|
269 |
+
param.requires_grad = False
|
270 |
+
|
271 |
+
for i in range(1, self.frozen_stages + 1):
|
272 |
+
m = self.blocks[i]
|
273 |
+
m.eval()
|
274 |
+
for param in m.parameters():
|
275 |
+
param.requires_grad = False
|
276 |
+
|
277 |
+
if self.freeze_attn:
|
278 |
+
for i in range(0, self.depth):
|
279 |
+
m = self.blocks[i]
|
280 |
+
m.attn.eval()
|
281 |
+
m.norm1.eval()
|
282 |
+
for param in m.attn.parameters():
|
283 |
+
param.requires_grad = False
|
284 |
+
for param in m.norm1.parameters():
|
285 |
+
param.requires_grad = False
|
286 |
+
|
287 |
+
if self.freeze_ffn:
|
288 |
+
self.pos_embed.requires_grad = False
|
289 |
+
self.patch_embed.eval()
|
290 |
+
for param in self.patch_embed.parameters():
|
291 |
+
param.requires_grad = False
|
292 |
+
for i in range(0, self.depth):
|
293 |
+
m = self.blocks[i]
|
294 |
+
m.mlp.eval()
|
295 |
+
m.norm2.eval()
|
296 |
+
for param in m.mlp.parameters():
|
297 |
+
param.requires_grad = False
|
298 |
+
for param in m.norm2.parameters():
|
299 |
+
param.requires_grad = False
|
300 |
+
|
301 |
+
def init_weights(self):
|
302 |
+
"""Initialize the weights in backbone.
|
303 |
+
Args:
|
304 |
+
pretrained (str, optional): Path to pre-trained weights.
|
305 |
+
Defaults to None.
|
306 |
+
"""
|
307 |
+
def _init_weights(m):
|
308 |
+
if isinstance(m, nn.Linear):
|
309 |
+
trunc_normal_(m.weight, std=.02)
|
310 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
311 |
+
nn.init.constant_(m.bias, 0)
|
312 |
+
elif isinstance(m, nn.LayerNorm):
|
313 |
+
nn.init.constant_(m.bias, 0)
|
314 |
+
nn.init.constant_(m.weight, 1.0)
|
315 |
+
|
316 |
+
self.apply(_init_weights)
|
317 |
+
|
318 |
+
def get_num_layers(self):
|
319 |
+
return len(self.blocks)
|
320 |
+
|
321 |
+
@torch.jit.ignore
|
322 |
+
def no_weight_decay(self):
|
323 |
+
return {'pos_embed', 'cls_token'}
|
324 |
+
|
325 |
+
def forward_features(self, x):
|
326 |
+
B, C, H, W = x.shape
|
327 |
+
x, (Hp, Wp) = self.patch_embed(x)
|
328 |
+
|
329 |
+
if self.pos_embed is not None:
|
330 |
+
# fit for multiple GPU training
|
331 |
+
# since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
|
332 |
+
x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
|
333 |
+
|
334 |
+
for blk in self.blocks:
|
335 |
+
if self.use_checkpoint:
|
336 |
+
x = checkpoint.checkpoint(blk, x)
|
337 |
+
else:
|
338 |
+
x = blk(x)
|
339 |
+
|
340 |
+
x = self.last_norm(x)
|
341 |
+
|
342 |
+
xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
|
343 |
+
|
344 |
+
return xp
|
345 |
+
|
346 |
+
def forward(self, x):
|
347 |
+
x = self.forward_features(x)
|
348 |
+
return x
|
349 |
+
|
350 |
+
def train(self, mode=True):
|
351 |
+
"""Convert the model into training mode."""
|
352 |
+
super().train(mode)
|
353 |
+
self._freeze_stages()
|
fourm/utils/hmr2_utils/hmr2/models/components/__init__.py
ADDED
File without changes
|
fourm/utils/hmr2_utils/hmr2/models/components/pose_transformer.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Based on the 4DHumans code base
|
3 |
+
# https://github.com/shubham-goel/4D-Humans
|
4 |
+
# --------------------------------------------------------
|
5 |
+
|
6 |
+
from inspect import isfunction
|
7 |
+
from typing import Callable, Optional
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from einops import rearrange
|
11 |
+
from einops.layers.torch import Rearrange
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
from .t_cond_mlp import (
|
15 |
+
AdaptiveLayerNorm1D,
|
16 |
+
FrequencyEmbedder,
|
17 |
+
normalization_layer,
|
18 |
+
)
|
19 |
+
# from .vit import Attention, FeedForward
|
20 |
+
|
21 |
+
|
22 |
+
def exists(val):
|
23 |
+
return val is not None
|
24 |
+
|
25 |
+
|
26 |
+
def default(val, d):
|
27 |
+
if exists(val):
|
28 |
+
return val
|
29 |
+
return d() if isfunction(d) else d
|
30 |
+
|
31 |
+
|
32 |
+
class PreNorm(nn.Module):
|
33 |
+
def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
|
34 |
+
super().__init__()
|
35 |
+
self.norm = normalization_layer(norm, dim, norm_cond_dim)
|
36 |
+
self.fn = fn
|
37 |
+
|
38 |
+
def forward(self, x: torch.Tensor, *args, **kwargs):
|
39 |
+
if isinstance(self.norm, AdaptiveLayerNorm1D):
|
40 |
+
return self.fn(self.norm(x, *args), **kwargs)
|
41 |
+
else:
|
42 |
+
return self.fn(self.norm(x), **kwargs)
|
43 |
+
|
44 |
+
|
45 |
+
class FeedForward(nn.Module):
|
46 |
+
def __init__(self, dim, hidden_dim, dropout=0.0):
|
47 |
+
super().__init__()
|
48 |
+
self.net = nn.Sequential(
|
49 |
+
nn.Linear(dim, hidden_dim),
|
50 |
+
nn.GELU(),
|
51 |
+
nn.Dropout(dropout),
|
52 |
+
nn.Linear(hidden_dim, dim),
|
53 |
+
nn.Dropout(dropout),
|
54 |
+
)
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
return self.net(x)
|
58 |
+
|
59 |
+
|
60 |
+
class Attention(nn.Module):
|
61 |
+
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
|
62 |
+
super().__init__()
|
63 |
+
inner_dim = dim_head * heads
|
64 |
+
project_out = not (heads == 1 and dim_head == dim)
|
65 |
+
|
66 |
+
self.heads = heads
|
67 |
+
self.scale = dim_head**-0.5
|
68 |
+
|
69 |
+
self.attend = nn.Softmax(dim=-1)
|
70 |
+
self.dropout = nn.Dropout(dropout)
|
71 |
+
|
72 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
73 |
+
|
74 |
+
self.to_out = (
|
75 |
+
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
|
76 |
+
if project_out
|
77 |
+
else nn.Identity()
|
78 |
+
)
|
79 |
+
|
80 |
+
def forward(self, x):
|
81 |
+
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
82 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
|
83 |
+
|
84 |
+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
85 |
+
|
86 |
+
attn = self.attend(dots)
|
87 |
+
attn = self.dropout(attn)
|
88 |
+
|
89 |
+
out = torch.matmul(attn, v)
|
90 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
91 |
+
return self.to_out(out)
|
92 |
+
|
93 |
+
|
94 |
+
class CrossAttention(nn.Module):
|
95 |
+
def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
96 |
+
super().__init__()
|
97 |
+
inner_dim = dim_head * heads
|
98 |
+
project_out = not (heads == 1 and dim_head == dim)
|
99 |
+
|
100 |
+
self.heads = heads
|
101 |
+
self.scale = dim_head**-0.5
|
102 |
+
|
103 |
+
self.attend = nn.Softmax(dim=-1)
|
104 |
+
self.dropout = nn.Dropout(dropout)
|
105 |
+
|
106 |
+
context_dim = default(context_dim, dim)
|
107 |
+
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
|
108 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
109 |
+
|
110 |
+
self.to_out = (
|
111 |
+
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
|
112 |
+
if project_out
|
113 |
+
else nn.Identity()
|
114 |
+
)
|
115 |
+
|
116 |
+
def forward(self, x, context=None):
|
117 |
+
context = default(context, x)
|
118 |
+
k, v = self.to_kv(context).chunk(2, dim=-1)
|
119 |
+
q = self.to_q(x)
|
120 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
|
121 |
+
|
122 |
+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
123 |
+
|
124 |
+
attn = self.attend(dots)
|
125 |
+
attn = self.dropout(attn)
|
126 |
+
|
127 |
+
out = torch.matmul(attn, v)
|
128 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
129 |
+
return self.to_out(out)
|
130 |
+
|
131 |
+
|
132 |
+
class Transformer(nn.Module):
|
133 |
+
def __init__(
|
134 |
+
self,
|
135 |
+
dim: int,
|
136 |
+
depth: int,
|
137 |
+
heads: int,
|
138 |
+
dim_head: int,
|
139 |
+
mlp_dim: int,
|
140 |
+
dropout: float = 0.0,
|
141 |
+
norm: str = "layer",
|
142 |
+
norm_cond_dim: int = -1,
|
143 |
+
):
|
144 |
+
super().__init__()
|
145 |
+
self.layers = nn.ModuleList([])
|
146 |
+
for _ in range(depth):
|
147 |
+
sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
|
148 |
+
ff = FeedForward(dim, mlp_dim, dropout=dropout)
|
149 |
+
self.layers.append(
|
150 |
+
nn.ModuleList(
|
151 |
+
[
|
152 |
+
PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
|
153 |
+
PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
|
154 |
+
]
|
155 |
+
)
|
156 |
+
)
|
157 |
+
|
158 |
+
def forward(self, x: torch.Tensor, *args):
|
159 |
+
for attn, ff in self.layers:
|
160 |
+
x = attn(x, *args) + x
|
161 |
+
x = ff(x, *args) + x
|
162 |
+
return x
|
163 |
+
|
164 |
+
|
165 |
+
class TransformerCrossAttn(nn.Module):
|
166 |
+
def __init__(
|
167 |
+
self,
|
168 |
+
dim: int,
|
169 |
+
depth: int,
|
170 |
+
heads: int,
|
171 |
+
dim_head: int,
|
172 |
+
mlp_dim: int,
|
173 |
+
dropout: float = 0.0,
|
174 |
+
norm: str = "layer",
|
175 |
+
norm_cond_dim: int = -1,
|
176 |
+
context_dim: Optional[int] = None,
|
177 |
+
):
|
178 |
+
super().__init__()
|
179 |
+
self.layers = nn.ModuleList([])
|
180 |
+
for _ in range(depth):
|
181 |
+
sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
|
182 |
+
ca = CrossAttention(
|
183 |
+
dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
|
184 |
+
)
|
185 |
+
ff = FeedForward(dim, mlp_dim, dropout=dropout)
|
186 |
+
self.layers.append(
|
187 |
+
nn.ModuleList(
|
188 |
+
[
|
189 |
+
PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
|
190 |
+
PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
|
191 |
+
PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
|
192 |
+
]
|
193 |
+
)
|
194 |
+
)
|
195 |
+
|
196 |
+
def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
|
197 |
+
if context_list is None:
|
198 |
+
context_list = [context] * len(self.layers)
|
199 |
+
if len(context_list) != len(self.layers):
|
200 |
+
raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
|
201 |
+
|
202 |
+
for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
|
203 |
+
x = self_attn(x, *args) + x
|
204 |
+
x = cross_attn(x, *args, context=context_list[i]) + x
|
205 |
+
x = ff(x, *args) + x
|
206 |
+
return x
|
207 |
+
|
208 |
+
|
209 |
+
class DropTokenDropout(nn.Module):
|
210 |
+
def __init__(self, p: float = 0.1):
|
211 |
+
super().__init__()
|
212 |
+
if p < 0 or p > 1:
|
213 |
+
raise ValueError(
|
214 |
+
"dropout probability has to be between 0 and 1, " "but got {}".format(p)
|
215 |
+
)
|
216 |
+
self.p = p
|
217 |
+
|
218 |
+
def forward(self, x: torch.Tensor):
|
219 |
+
# x: (batch_size, seq_len, dim)
|
220 |
+
if self.training and self.p > 0:
|
221 |
+
zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
|
222 |
+
# TODO: permutation idx for each batch using torch.argsort
|
223 |
+
if zero_mask.any():
|
224 |
+
x = x[:, ~zero_mask, :]
|
225 |
+
return x
|
226 |
+
|
227 |
+
|
228 |
+
class ZeroTokenDropout(nn.Module):
|
229 |
+
def __init__(self, p: float = 0.1):
|
230 |
+
super().__init__()
|
231 |
+
if p < 0 or p > 1:
|
232 |
+
raise ValueError(
|
233 |
+
"dropout probability has to be between 0 and 1, " "but got {}".format(p)
|
234 |
+
)
|
235 |
+
self.p = p
|
236 |
+
|
237 |
+
def forward(self, x: torch.Tensor):
|
238 |
+
# x: (batch_size, seq_len, dim)
|
239 |
+
if self.training and self.p > 0:
|
240 |
+
zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
|
241 |
+
# Zero-out the masked tokens
|
242 |
+
x[zero_mask, :] = 0
|
243 |
+
return x
|
244 |
+
|
245 |
+
|
246 |
+
class TransformerEncoder(nn.Module):
|
247 |
+
def __init__(
|
248 |
+
self,
|
249 |
+
num_tokens: int,
|
250 |
+
token_dim: int,
|
251 |
+
dim: int,
|
252 |
+
depth: int,
|
253 |
+
heads: int,
|
254 |
+
mlp_dim: int,
|
255 |
+
dim_head: int = 64,
|
256 |
+
dropout: float = 0.0,
|
257 |
+
emb_dropout: float = 0.0,
|
258 |
+
emb_dropout_type: str = "drop",
|
259 |
+
emb_dropout_loc: str = "token",
|
260 |
+
norm: str = "layer",
|
261 |
+
norm_cond_dim: int = -1,
|
262 |
+
token_pe_numfreq: int = -1,
|
263 |
+
):
|
264 |
+
super().__init__()
|
265 |
+
if token_pe_numfreq > 0:
|
266 |
+
token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
|
267 |
+
self.to_token_embedding = nn.Sequential(
|
268 |
+
Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
|
269 |
+
FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
|
270 |
+
Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
|
271 |
+
nn.Linear(token_dim_new, dim),
|
272 |
+
)
|
273 |
+
else:
|
274 |
+
self.to_token_embedding = nn.Linear(token_dim, dim)
|
275 |
+
self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
|
276 |
+
if emb_dropout_type == "drop":
|
277 |
+
self.dropout = DropTokenDropout(emb_dropout)
|
278 |
+
elif emb_dropout_type == "zero":
|
279 |
+
self.dropout = ZeroTokenDropout(emb_dropout)
|
280 |
+
else:
|
281 |
+
raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
|
282 |
+
self.emb_dropout_loc = emb_dropout_loc
|
283 |
+
|
284 |
+
self.transformer = Transformer(
|
285 |
+
dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
|
286 |
+
)
|
287 |
+
|
288 |
+
def forward(self, inp: torch.Tensor, *args, **kwargs):
|
289 |
+
x = inp
|
290 |
+
|
291 |
+
if self.emb_dropout_loc == "input":
|
292 |
+
x = self.dropout(x)
|
293 |
+
x = self.to_token_embedding(x)
|
294 |
+
|
295 |
+
if self.emb_dropout_loc == "token":
|
296 |
+
x = self.dropout(x)
|
297 |
+
b, n, _ = x.shape
|
298 |
+
x += self.pos_embedding[:, :n]
|
299 |
+
|
300 |
+
if self.emb_dropout_loc == "token_afterpos":
|
301 |
+
x = self.dropout(x)
|
302 |
+
x = self.transformer(x, *args)
|
303 |
+
return x
|
304 |
+
|
305 |
+
|
306 |
+
class TransformerDecoder(nn.Module):
|
307 |
+
def __init__(
|
308 |
+
self,
|
309 |
+
num_tokens: int,
|
310 |
+
token_dim: int,
|
311 |
+
dim: int,
|
312 |
+
depth: int,
|
313 |
+
heads: int,
|
314 |
+
mlp_dim: int,
|
315 |
+
dim_head: int = 64,
|
316 |
+
dropout: float = 0.0,
|
317 |
+
emb_dropout: float = 0.0,
|
318 |
+
emb_dropout_type: str = 'drop',
|
319 |
+
norm: str = "layer",
|
320 |
+
norm_cond_dim: int = -1,
|
321 |
+
context_dim: Optional[int] = None,
|
322 |
+
skip_token_embedding: bool = False,
|
323 |
+
):
|
324 |
+
super().__init__()
|
325 |
+
if not skip_token_embedding:
|
326 |
+
self.to_token_embedding = nn.Linear(token_dim, dim)
|
327 |
+
else:
|
328 |
+
self.to_token_embedding = nn.Identity()
|
329 |
+
if token_dim != dim:
|
330 |
+
raise ValueError(
|
331 |
+
f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
|
332 |
+
)
|
333 |
+
|
334 |
+
self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
|
335 |
+
if emb_dropout_type == "drop":
|
336 |
+
self.dropout = DropTokenDropout(emb_dropout)
|
337 |
+
elif emb_dropout_type == "zero":
|
338 |
+
self.dropout = ZeroTokenDropout(emb_dropout)
|
339 |
+
elif emb_dropout_type == "normal":
|
340 |
+
self.dropout = nn.Dropout(emb_dropout)
|
341 |
+
|
342 |
+
self.transformer = TransformerCrossAttn(
|
343 |
+
dim,
|
344 |
+
depth,
|
345 |
+
heads,
|
346 |
+
dim_head,
|
347 |
+
mlp_dim,
|
348 |
+
dropout,
|
349 |
+
norm=norm,
|
350 |
+
norm_cond_dim=norm_cond_dim,
|
351 |
+
context_dim=context_dim,
|
352 |
+
)
|
353 |
+
|
354 |
+
def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
|
355 |
+
x = self.to_token_embedding(inp)
|
356 |
+
b, n, _ = x.shape
|
357 |
+
|
358 |
+
x = self.dropout(x)
|
359 |
+
x += self.pos_embedding[:, :n]
|
360 |
+
|
361 |
+
x = self.transformer(x, *args, context=context, context_list=context_list)
|
362 |
+
return x
|
363 |
+
|
fourm/utils/hmr2_utils/hmr2/models/components/t_cond_mlp.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Based on the 4DHumans code base
|
3 |
+
# https://github.com/shubham-goel/4D-Humans
|
4 |
+
# --------------------------------------------------------
|
5 |
+
|
6 |
+
import copy
|
7 |
+
from typing import List, Optional
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
class AdaptiveLayerNorm1D(torch.nn.Module):
|
13 |
+
def __init__(self, data_dim: int, norm_cond_dim: int):
|
14 |
+
super().__init__()
|
15 |
+
if data_dim <= 0:
|
16 |
+
raise ValueError(f"data_dim must be positive, but got {data_dim}")
|
17 |
+
if norm_cond_dim <= 0:
|
18 |
+
raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
|
19 |
+
self.norm = torch.nn.LayerNorm(
|
20 |
+
data_dim
|
21 |
+
) # TODO: Check if elementwise_affine=True is correct
|
22 |
+
self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
|
23 |
+
torch.nn.init.zeros_(self.linear.weight)
|
24 |
+
torch.nn.init.zeros_(self.linear.bias)
|
25 |
+
|
26 |
+
def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
27 |
+
# x: (batch, ..., data_dim)
|
28 |
+
# t: (batch, norm_cond_dim)
|
29 |
+
# return: (batch, data_dim)
|
30 |
+
x = self.norm(x)
|
31 |
+
alpha, beta = self.linear(t).chunk(2, dim=-1)
|
32 |
+
|
33 |
+
# Add singleton dimensions to alpha and beta
|
34 |
+
if x.dim() > 2:
|
35 |
+
alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
|
36 |
+
beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
|
37 |
+
|
38 |
+
return x * (1 + alpha) + beta
|
39 |
+
|
40 |
+
|
41 |
+
class SequentialCond(torch.nn.Sequential):
|
42 |
+
def forward(self, input, *args, **kwargs):
|
43 |
+
for module in self:
|
44 |
+
if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
|
45 |
+
# print(f'Passing on args to {module}', [a.shape for a in args])
|
46 |
+
input = module(input, *args, **kwargs)
|
47 |
+
else:
|
48 |
+
# print(f'Skipping passing args to {module}', [a.shape for a in args])
|
49 |
+
input = module(input)
|
50 |
+
return input
|
51 |
+
|
52 |
+
|
53 |
+
def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
|
54 |
+
if norm == "batch":
|
55 |
+
return torch.nn.BatchNorm1d(dim)
|
56 |
+
elif norm == "layer":
|
57 |
+
return torch.nn.LayerNorm(dim)
|
58 |
+
elif norm == "ada":
|
59 |
+
assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
|
60 |
+
return AdaptiveLayerNorm1D(dim, norm_cond_dim)
|
61 |
+
elif norm is None:
|
62 |
+
return torch.nn.Identity()
|
63 |
+
else:
|
64 |
+
raise ValueError(f"Unknown norm: {norm}")
|
65 |
+
|
66 |
+
|
67 |
+
def linear_norm_activ_dropout(
|
68 |
+
input_dim: int,
|
69 |
+
output_dim: int,
|
70 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
71 |
+
bias: bool = True,
|
72 |
+
norm: Optional[str] = "layer", # Options: ada/batch/layer
|
73 |
+
dropout: float = 0.0,
|
74 |
+
norm_cond_dim: int = -1,
|
75 |
+
) -> SequentialCond:
|
76 |
+
layers = []
|
77 |
+
layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
|
78 |
+
if norm is not None:
|
79 |
+
layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
|
80 |
+
layers.append(copy.deepcopy(activation))
|
81 |
+
if dropout > 0.0:
|
82 |
+
layers.append(torch.nn.Dropout(dropout))
|
83 |
+
return SequentialCond(*layers)
|
84 |
+
|
85 |
+
|
86 |
+
def create_simple_mlp(
|
87 |
+
input_dim: int,
|
88 |
+
hidden_dims: List[int],
|
89 |
+
output_dim: int,
|
90 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
91 |
+
bias: bool = True,
|
92 |
+
norm: Optional[str] = "layer", # Options: ada/batch/layer
|
93 |
+
dropout: float = 0.0,
|
94 |
+
norm_cond_dim: int = -1,
|
95 |
+
) -> SequentialCond:
|
96 |
+
layers = []
|
97 |
+
prev_dim = input_dim
|
98 |
+
for hidden_dim in hidden_dims:
|
99 |
+
layers.extend(
|
100 |
+
linear_norm_activ_dropout(
|
101 |
+
prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
|
102 |
+
)
|
103 |
+
)
|
104 |
+
prev_dim = hidden_dim
|
105 |
+
layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
|
106 |
+
return SequentialCond(*layers)
|
107 |
+
|
108 |
+
|
109 |
+
class ResidualMLPBlock(torch.nn.Module):
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
input_dim: int,
|
113 |
+
hidden_dim: int,
|
114 |
+
num_hidden_layers: int,
|
115 |
+
output_dim: int,
|
116 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
117 |
+
bias: bool = True,
|
118 |
+
norm: Optional[str] = "layer", # Options: ada/batch/layer
|
119 |
+
dropout: float = 0.0,
|
120 |
+
norm_cond_dim: int = -1,
|
121 |
+
):
|
122 |
+
super().__init__()
|
123 |
+
if not (input_dim == output_dim == hidden_dim):
|
124 |
+
raise NotImplementedError(
|
125 |
+
f"input_dim {input_dim} != output_dim {output_dim} is not implemented"
|
126 |
+
)
|
127 |
+
|
128 |
+
layers = []
|
129 |
+
prev_dim = input_dim
|
130 |
+
for i in range(num_hidden_layers):
|
131 |
+
layers.append(
|
132 |
+
linear_norm_activ_dropout(
|
133 |
+
prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
|
134 |
+
)
|
135 |
+
)
|
136 |
+
prev_dim = hidden_dim
|
137 |
+
self.model = SequentialCond(*layers)
|
138 |
+
self.skip = torch.nn.Identity()
|
139 |
+
|
140 |
+
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
141 |
+
return x + self.model(x, *args, **kwargs)
|
142 |
+
|
143 |
+
|
144 |
+
class ResidualMLP(torch.nn.Module):
|
145 |
+
def __init__(
|
146 |
+
self,
|
147 |
+
input_dim: int,
|
148 |
+
hidden_dim: int,
|
149 |
+
num_hidden_layers: int,
|
150 |
+
output_dim: int,
|
151 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
152 |
+
bias: bool = True,
|
153 |
+
norm: Optional[str] = "layer", # Options: ada/batch/layer
|
154 |
+
dropout: float = 0.0,
|
155 |
+
num_blocks: int = 1,
|
156 |
+
norm_cond_dim: int = -1,
|
157 |
+
):
|
158 |
+
super().__init__()
|
159 |
+
self.input_dim = input_dim
|
160 |
+
self.model = SequentialCond(
|
161 |
+
linear_norm_activ_dropout(
|
162 |
+
input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
|
163 |
+
),
|
164 |
+
*[
|
165 |
+
ResidualMLPBlock(
|
166 |
+
hidden_dim,
|
167 |
+
hidden_dim,
|
168 |
+
num_hidden_layers,
|
169 |
+
hidden_dim,
|
170 |
+
activation,
|
171 |
+
bias,
|
172 |
+
norm,
|
173 |
+
dropout,
|
174 |
+
norm_cond_dim,
|
175 |
+
)
|
176 |
+
for _ in range(num_blocks)
|
177 |
+
],
|
178 |
+
torch.nn.Linear(hidden_dim, output_dim, bias=bias),
|
179 |
+
)
|
180 |
+
|
181 |
+
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
182 |
+
return self.model(x, *args, **kwargs)
|
183 |
+
|
184 |
+
|
185 |
+
class FrequencyEmbedder(torch.nn.Module):
|
186 |
+
def __init__(self, num_frequencies, max_freq_log2):
|
187 |
+
super().__init__()
|
188 |
+
frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
|
189 |
+
self.register_buffer("frequencies", frequencies)
|
190 |
+
|
191 |
+
def forward(self, x):
|
192 |
+
# x should be of size (N,) or (N, D)
|
193 |
+
N = x.size(0)
|
194 |
+
if x.dim() == 1: # (N,)
|
195 |
+
x = x.unsqueeze(1) # (N, D) where D=1
|
196 |
+
x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
|
197 |
+
scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
|
198 |
+
s = torch.sin(scaled)
|
199 |
+
c = torch.cos(scaled)
|
200 |
+
embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
|
201 |
+
N, -1
|
202 |
+
) # (N, D * 2 * num_frequencies + D)
|
203 |
+
return embedded
|
204 |
+
|
fourm/utils/hmr2_utils/hmr2/models/heads/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .smpl_head import build_smpl_head
|
fourm/utils/hmr2_utils/hmr2/models/heads/smpl_head.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Based on the 4DHumans code base
|
3 |
+
# https://github.com/shubham-goel/4D-Humans
|
4 |
+
# --------------------------------------------------------
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import numpy as np
|
10 |
+
import einops
|
11 |
+
|
12 |
+
from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat
|
13 |
+
from ..components.pose_transformer import TransformerDecoder
|
14 |
+
|
15 |
+
def build_smpl_head(cfg):
|
16 |
+
smpl_head_type = cfg.MODEL.SMPL_HEAD.get('TYPE', 'hmr')
|
17 |
+
if smpl_head_type == 'transformer_decoder':
|
18 |
+
return SMPLTransformerDecoderHead(cfg)
|
19 |
+
else:
|
20 |
+
raise ValueError('Unknown SMPL head type: {}'.format(smpl_head_type))
|
21 |
+
|
22 |
+
class SMPLTransformerDecoderHead(nn.Module):
|
23 |
+
""" Cross-attention based SMPL Transformer decoder
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, cfg):
|
27 |
+
super().__init__()
|
28 |
+
self.cfg = cfg
|
29 |
+
self.joint_rep_type = cfg.MODEL.SMPL_HEAD.get('JOINT_REP', '6d')
|
30 |
+
self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
|
31 |
+
npose = self.joint_rep_dim * (cfg.SMPL.NUM_BODY_JOINTS + 1)
|
32 |
+
self.npose = npose
|
33 |
+
self.input_is_mean_shape = cfg.MODEL.SMPL_HEAD.get('TRANSFORMER_INPUT', 'zero') == 'mean_shape'
|
34 |
+
transformer_args = dict(
|
35 |
+
num_tokens=1,
|
36 |
+
token_dim=(npose + 10 + 3) if self.input_is_mean_shape else 1,
|
37 |
+
dim=1024,
|
38 |
+
)
|
39 |
+
        transformer_args = (transformer_args | dict(cfg.MODEL.SMPL_HEAD.TRANSFORMER_DECODER))
        self.transformer = TransformerDecoder(
            **transformer_args
        )
        dim = transformer_args['dim']
        self.decpose = nn.Linear(dim, npose)
        self.decshape = nn.Linear(dim, 10)
        self.deccam = nn.Linear(dim, 3)

        if cfg.MODEL.SMPL_HEAD.get('INIT_DECODER_XAVIER', False):
            # True by default in MLP. False by default in Transformer
            nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
            nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
            nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        mean_params = np.load(cfg.SMPL.MEAN_PARAMS)
        init_body_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
        init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
        self.register_buffer('init_body_pose', init_body_pose)
        self.register_buffer('init_betas', init_betas)
        self.register_buffer('init_cam', init_cam)

    def forward(self, x, **kwargs):

        batch_size = x.shape[0]
        # vit pretrained backbone is channel-first. Change to token-first
        x = einops.rearrange(x, 'b c h w -> b (h w) c')

        init_body_pose = self.init_body_pose.expand(batch_size, -1)
        init_betas = self.init_betas.expand(batch_size, -1)
        init_cam = self.init_cam.expand(batch_size, -1)

        # TODO: Convert init_body_pose to aa rep if needed
        if self.joint_rep_type == 'aa':
            raise NotImplementedError

        pred_body_pose = init_body_pose
        pred_betas = init_betas
        pred_cam = init_cam
        pred_body_pose_list = []
        pred_betas_list = []
        pred_cam_list = []
        for i in range(self.cfg.MODEL.SMPL_HEAD.get('IEF_ITERS', 1)):
            # Input token to transformer is zero token
            if self.input_is_mean_shape:
                token = torch.cat([pred_body_pose, pred_betas, pred_cam], dim=1)[:, None, :]
            else:
                token = torch.zeros(batch_size, 1, 1).to(x.device)

            # Pass through transformer
            token_out = self.transformer(token, context=x)
            token_out = token_out.squeeze(1)  # (B, C)

            # Readout from token_out
            pred_body_pose = self.decpose(token_out) + pred_body_pose
            pred_betas = self.decshape(token_out) + pred_betas
            pred_cam = self.deccam(token_out) + pred_cam
            pred_body_pose_list.append(pred_body_pose)
            pred_betas_list.append(pred_betas)
            pred_cam_list.append(pred_cam)

        # Convert self.joint_rep_type -> rotmat
        joint_conversion_fn = {
            '6d': rot6d_to_rotmat,
            'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())
        }[self.joint_rep_type]

        pred_smpl_params_list = {}
        pred_smpl_params_list['body_pose'] = torch.cat([joint_conversion_fn(pbp).view(batch_size, -1, 3, 3)[:, 1:, :, :] for pbp in pred_body_pose_list], dim=0)
        pred_smpl_params_list['betas'] = torch.cat(pred_betas_list, dim=0)
        pred_smpl_params_list['cam'] = torch.cat(pred_cam_list, dim=0)
        pred_body_pose = joint_conversion_fn(pred_body_pose).view(batch_size, self.cfg.SMPL.NUM_BODY_JOINTS+1, 3, 3)

        pred_smpl_params = {'global_orient': pred_body_pose[:, [0]],
                            'body_pose': pred_body_pose[:, 1:],
                            'betas': pred_betas}
        return pred_smpl_params, pred_cam, pred_smpl_params_list
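The loop above is the iterative error feedback (IEF) readout: at each iteration the decoder token is projected to residual updates for pose, shape, and camera, which are added onto the running estimates. Below is a minimal, self-contained sketch of just that readout, with a random tensor standing in for the transformer output and illustrative dimensions that are not taken from any config.

import torch
import torch.nn as nn

B, C, npose = 2, 1280, 24 * 6                      # batch size, token width, 6D pose for 24 joints (illustrative)
decpose, decshape, deccam = nn.Linear(C, npose), nn.Linear(C, 10), nn.Linear(C, 3)
token_out = torch.randn(B, C)                      # stand-in for the TransformerDecoder output token
pred_pose = torch.zeros(B, npose)                  # the real head starts from the SMPL mean parameters
pred_betas, pred_cam = torch.zeros(B, 10), torch.zeros(B, 3)
for _ in range(3):                                 # IEF_ITERS
    pred_pose = decpose(token_out) + pred_pose     # residual update, as in the head above
    pred_betas = decshape(token_out) + pred_betas
    pred_cam = deccam(token_out) + pred_cam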
fourm/utils/hmr2_utils/hmr2/models/hmr2.py
ADDED
@@ -0,0 +1,117 @@
# --------------------------------------------------------
# Based on the 4DHumans code base
# https://github.com/shubham-goel/4D-Humans
# --------------------------------------------------------

import torch
from typing import Any, Dict, Mapping, Tuple

from yacs.config import CfgNode

from ..utils import SkeletonRenderer, MeshRenderer
from ..utils.geometry import perspective_projection
from .backbones import create_backbone
from .heads import build_smpl_head
from . import SMPL

class HMR2(torch.nn.Module):

    def __init__(self, cfg: CfgNode, init_renderer: bool = True):
        """
        Setup HMR2 model
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode
        """
        super().__init__()

        # Save hyperparameters
        self.save_hyperparameters(logger=False, ignore=['init_renderer'])

        self.cfg = cfg
        # Create backbone feature extractor
        self.backbone = create_backbone(cfg)
        if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
            self.backbone.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict'])

        # Create SMPL head
        self.smpl_head = build_smpl_head(cfg)

        # Instantiate SMPL model
        smpl_cfg = {k.lower(): v for k, v in dict(cfg.SMPL).items()}
        self.smpl = SMPL(**smpl_cfg)

        # Buffer that shows whether we need to initialize ActNorm layers
        self.register_buffer('initialized', torch.tensor(False))
        # Setup renderer for visualization
        if init_renderer:
            self.renderer = SkeletonRenderer(self.cfg)
            self.mesh_renderer = MeshRenderer(self.cfg, faces=self.smpl.faces)
        else:
            self.renderer = None
            self.mesh_renderer = None

        # Disable automatic optimization since we use adversarial training
        self.automatic_optimization = False

    def forward_step(self, batch: Dict, train: bool = False) -> Dict:
        """
        Run a forward step of the network
        Args:
            batch (Dict): Dictionary containing batch data
            train (bool): Flag indicating whether it is training or validation mode
        Returns:
            Dict: Dictionary containing the regression output
        """

        # Use RGB image as input
        x = batch['img']
        batch_size = x.shape[0]

        # Compute conditioning features using the backbone
        # if using ViT backbone, we need to use a different aspect ratio
        conditioning_feats = self.backbone(x[:, :, :, 32:-32])

        pred_smpl_params, pred_cam, _ = self.smpl_head(conditioning_feats)

        # Store useful regression outputs to the output dict
        output = {}
        output['pred_cam'] = pred_cam
        output['pred_smpl_params'] = {k: v.clone() for k, v in pred_smpl_params.items()}

        # Compute camera translation
        device = pred_smpl_params['body_pose'].device
        dtype = pred_smpl_params['body_pose'].dtype
        focal_length = self.cfg.EXTRA.FOCAL_LENGTH * torch.ones(batch_size, 2, device=device, dtype=dtype)
        pred_cam_t = torch.stack([pred_cam[:, 1],
                                  pred_cam[:, 2],
                                  2 * focal_length[:, 0] / (self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] + 1e-9)], dim=-1)
        output['pred_cam_t'] = pred_cam_t
        output['focal_length'] = focal_length

        # Compute model vertices, joints and the projected joints
        pred_smpl_params['global_orient'] = pred_smpl_params['global_orient'].reshape(batch_size, -1, 3, 3)
        pred_smpl_params['body_pose'] = pred_smpl_params['body_pose'].reshape(batch_size, -1, 3, 3)
        pred_smpl_params['betas'] = pred_smpl_params['betas'].reshape(batch_size, -1)
        smpl_output = self.smpl(**{k: v.float() for k, v in pred_smpl_params.items()}, pose2rot=False)
        pred_keypoints_3d = smpl_output.joints
        pred_vertices = smpl_output.vertices
        output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
        output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)
        pred_cam_t = pred_cam_t.reshape(-1, 3)
        focal_length = focal_length.reshape(-1, 2)
        pred_keypoints_2d = perspective_projection(pred_keypoints_3d,
                                                   translation=pred_cam_t,
                                                   focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE)

        output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)
        return output

    def forward(self, batch: Dict) -> Dict:
        """
        Run a forward step of the network in val mode
        Args:
            batch (Dict): Dictionary containing batch data
        Returns:
            Dict: Dictionary containing the regression output
        """
        return self.forward_step(batch, train=False)
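forward_step converts the predicted weak-perspective camera (s, tx, ty) into a full translation by recovering depth as 2·f / (IMAGE_SIZE·s). A small worked example of that conversion follows; FOCAL_LENGTH and IMAGE_SIZE are illustrative values here, not read from a real config.

import torch

FOCAL_LENGTH, IMAGE_SIZE = 5000.0, 256
pred_cam = torch.tensor([[0.9, 0.02, -0.01]])          # (s, tx, ty) for a single sample
focal_length = FOCAL_LENGTH * torch.ones(1, 2)
pred_cam_t = torch.stack([pred_cam[:, 1],
                          pred_cam[:, 2],
                          2 * focal_length[:, 0] / (IMAGE_SIZE * pred_cam[:, 0] + 1e-9)], dim=-1)
# pred_cam_t is approximately [[0.02, -0.01, 43.40]]: the larger the predicted scale, the closer the body.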
fourm/utils/hmr2_utils/hmr2/models/smpl_wrapper.py
ADDED
@@ -0,0 +1,47 @@
# --------------------------------------------------------
# Based on the 4DHumans and ProHMR code bases
# https://github.com/shubham-goel/4D-Humans
# https://github.com/nkolot/ProHMR
# --------------------------------------------------------

import torch
import numpy as np
import pickle
from typing import Optional
import smplx
from smplx.lbs import vertices2joints
from smplx.utils import SMPLOutput


class SMPL(smplx.SMPLLayer):
    def __init__(self, *args, joint_regressor_extra: Optional[str] = None, update_hips: bool = False, **kwargs):
        """
        Extension of the official SMPL implementation to support more joints.
        Args:
            Same as SMPLLayer.
            joint_regressor_extra (str): Path to extra joint regressor.
        """
        super(SMPL, self).__init__(*args, **kwargs)
        smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4,
                            7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]

        if joint_regressor_extra is not None:
            self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
        self.register_buffer('joint_map', torch.tensor(smpl_to_openpose, dtype=torch.long))
        self.update_hips = update_hips

    def forward(self, *args, **kwargs) -> SMPLOutput:
        """
        Run forward pass. Same as SMPL and also append an extra set of joints if joint_regressor_extra is specified.
        """
        smpl_output = super(SMPL, self).forward(*args, **kwargs)
        joints = smpl_output.joints[:, self.joint_map, :]
        if self.update_hips:
            joints[:, [9, 12]] = joints[:, [9, 12]] + \
                0.25 * (joints[:, [9, 12]] - joints[:, [12, 9]]) + \
                0.5 * (joints[:, [8]] - 0.5 * (joints[:, [9, 12]] + joints[:, [12, 9]]))
        if hasattr(self, 'joint_regressor_extra'):
            extra_joints = vertices2joints(self.joint_regressor_extra, smpl_output.vertices)
            joints = torch.cat([joints, extra_joints], dim=1)
        smpl_output.joints = joints
        return smpl_output
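The wrapper's main addition is reordering SMPL joints (plus extra vertex-regressed joints at indices 25 and above) into the OpenPose convention via joint_map. A toy illustration of that gather, using a random placeholder tensor instead of real SMPL output:

import torch

joints = torch.randn(1, 45, 3)                          # placeholder joint set, not real SMPL output
joint_map = torch.tensor([24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4,
                          7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34])
openpose_joints = joints[:, joint_map, :]               # (1, 25, 3), OpenPose body-25 ordering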
fourm/utils/hmr2_utils/hmr2/utils/__init__.py
ADDED
@@ -0,0 +1,31 @@
# --------------------------------------------------------
# Based on the 4DHumans and ProHMR code bases
# https://github.com/shubham-goel/4D-Humans
# https://github.com/nkolot/ProHMR
# --------------------------------------------------------

import torch
from typing import Any

from .renderer import Renderer
from .mesh_renderer import MeshRenderer
from .skeleton_renderer import SkeletonRenderer
# from .pose_utils import eval_pose, Evaluator

def recursive_to(x: Any, target: torch.device):
    """
    Recursively transfer a batch of data to the target device
    Args:
        x (Any): Batch of data.
        target (torch.device): Target device.
    Returns:
        Batch of data where all tensors are transferred to the target device.
    """
    if isinstance(x, dict):
        return {k: recursive_to(v, target) for k, v in x.items()}
    elif isinstance(x, torch.Tensor):
        return x.to(target)
    elif isinstance(x, list):
        return [recursive_to(i, target) for i in x]
    else:
        return x
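A short usage example for recursive_to, assuming the function above is in scope: tensors nested inside dicts and lists are moved to the target device, while everything else is returned unchanged.

import torch

batch = {'img': torch.zeros(2, 3, 256, 256),
         'meta': {'ids': [torch.tensor(0), torch.tensor(1)], 'names': ['a', 'b']}}
batch = recursive_to(batch, torch.device('cpu'))   # tensors moved, the list of strings is left as-is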
fourm/utils/hmr2_utils/hmr2/utils/geometry.py
ADDED
@@ -0,0 +1,109 @@
# --------------------------------------------------------
# Based on the 4DHumans, ProHMR, and SPIN code bases
# https://github.com/shubham-goel/4D-Humans
# https://github.com/nkolot/ProHMR
# https://github.com/nkolot/SPIN
# --------------------------------------------------------

from typing import Optional
import torch
from torch.nn import functional as F

def aa_to_rotmat(theta: torch.Tensor):
    """
    Convert axis-angle representation to rotation matrix.
    Works by first converting it to a quaternion.
    Args:
        theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
    Returns:
        torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
    """
    norm = torch.norm(theta + 1e-8, p=2, dim=1)
    angle = torch.unsqueeze(norm, -1)
    normalized = torch.div(theta, angle)
    angle = angle * 0.5
    v_cos = torch.cos(angle)
    v_sin = torch.sin(angle)
    quat = torch.cat([v_cos, v_sin * normalized], dim=1)
    return quat_to_rotmat(quat)

def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
    """
    Convert quaternion representation to rotation matrix.
    Args:
        quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
    Returns:
        torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
    """
    norm_quat = quat
    norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]

    B = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w*x, w*y, w*z
    xy, xz, yz = x*y, x*z, y*z

    rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
                          2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
                          2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
    return rotMat


def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
    """
    Convert 6D rotation representation to 3x3 rotation matrix.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
    Args:
        x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
    Returns:
        torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
    """
    x = x.reshape(-1, 2, 3).permute(0, 2, 1).contiguous()
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    b1 = F.normalize(a1)
    b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
    b3 = torch.cross(b1, b2)
    return torch.stack((b1, b2, b3), dim=-1)

def perspective_projection(points: torch.Tensor,
                           translation: torch.Tensor,
                           focal_length: torch.Tensor,
                           camera_center: Optional[torch.Tensor] = None,
                           rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Computes the perspective projection of a set of 3D points.
    Args:
        points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
        translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
        focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
        camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
        rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
    Returns:
        torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
    """
    batch_size = points.shape[0]
    if rotation is None:
        rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
    if camera_center is None:
        camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
    # Populate intrinsic camera matrix K.
    K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
    K[:, 0, 0] = focal_length[:, 0]
    K[:, 1, 1] = focal_length[:, 1]
    K[:, 2, 2] = 1.
    K[:, :-1, -1] = camera_center

    # Transform points
    points = torch.einsum('bij,bkj->bki', rotation, points)
    points = points + translation.unsqueeze(1)

    # Apply perspective distortion
    projected_points = points / points[:, :, -1].unsqueeze(-1)

    # Apply camera intrinsics
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points)

    return projected_points[:, :, :-1]
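A quick sanity check for the conversions above, assuming the functions are in scope: the 6D vector holding the first two columns of the identity, and the identity quaternion (w, x, y, z) = (1, 0, 0, 0), both map to the identity rotation.

import torch

R = rot6d_to_rotmat(torch.tensor([[1., 0., 0., 0., 1., 0.]]))      # Gram-Schmidt recovers I
assert torch.allclose(R, torch.eye(3).unsqueeze(0), atol=1e-6)
R = quat_to_rotmat(torch.tensor([[1., 0., 0., 0.]]))
assert torch.allclose(R, torch.eye(3).unsqueeze(0), atol=1e-6)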
fourm/utils/hmr2_utils/hmr2/utils/mesh_renderer.py
ADDED
@@ -0,0 +1,155 @@
# --------------------------------------------------------
# Based on the 4DHumans and ProHMR code bases
# https://github.com/shubham-goel/4D-Humans
# https://github.com/nkolot/ProHMR
# --------------------------------------------------------

import os
if 'PYOPENGL_PLATFORM' not in os.environ:
    os.environ['PYOPENGL_PLATFORM'] = 'egl'
import torch
from torchvision.utils import make_grid
import numpy as np
import pyrender
import trimesh
import cv2
import torch.nn.functional as F

from .render_openpose import render_openpose

def create_raymond_lights():
    import pyrender
    thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
    phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])

    nodes = []

    for phi, theta in zip(phis, thetas):
        xp = np.sin(theta) * np.cos(phi)
        yp = np.sin(theta) * np.sin(phi)
        zp = np.cos(theta)

        z = np.array([xp, yp, zp])
        z = z / np.linalg.norm(z)
        x = np.array([-z[1], z[0], 0.0])
        if np.linalg.norm(x) == 0:
            x = np.array([1.0, 0.0, 0.0])
        x = x / np.linalg.norm(x)
        y = np.cross(z, x)

        matrix = np.eye(4)
        matrix[:3, :3] = np.c_[x, y, z]
        nodes.append(pyrender.Node(
            light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0),
            matrix=matrix
        ))

    return nodes

class MeshRenderer:

    def __init__(self, cfg, faces=None):
        self.cfg = cfg
        self.focal_length = cfg.EXTRA.FOCAL_LENGTH
        self.img_res = cfg.MODEL.IMAGE_SIZE
        self.renderer = pyrender.OffscreenRenderer(viewport_width=self.img_res,
                                                   viewport_height=self.img_res,
                                                   point_size=1.0)

        self.camera_center = [self.img_res // 2, self.img_res // 2]
        self.faces = faces

    def visualize(self, vertices, camera_translation, images, focal_length=None, nrow=3, padding=2):
        images_np = np.transpose(images, (0, 2, 3, 1))
        rend_imgs = []
        for i in range(vertices.shape[0]):
            fl = self.focal_length
            rend_img = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=False), (2, 0, 1))).float()
            rend_img_side = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=True), (2, 0, 1))).float()
            rend_imgs.append(torch.from_numpy(images[i]))
            rend_imgs.append(rend_img)
            rend_imgs.append(rend_img_side)
        rend_imgs = make_grid(rend_imgs, nrow=nrow, padding=padding)
        return rend_imgs

    def visualize_tensorboard(self, vertices, camera_translation, images, pred_keypoints, gt_keypoints, focal_length=None, nrow=5, padding=2):
        images_np = np.transpose(images, (0, 2, 3, 1))
        rend_imgs = []
        pred_keypoints = np.concatenate((pred_keypoints, np.ones_like(pred_keypoints)[:, :, [0]]), axis=-1)
        pred_keypoints = self.img_res * (pred_keypoints + 0.5)
        gt_keypoints[:, :, :-1] = self.img_res * (gt_keypoints[:, :, :-1] + 0.5)
        keypoint_matches = [(1, 12), (2, 8), (3, 7), (4, 6), (5, 9), (6, 10), (7, 11), (8, 14), (9, 2), (10, 1), (11, 0), (12, 3), (13, 4), (14, 5)]
        for i in range(vertices.shape[0]):
            fl = self.focal_length
            rend_img = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=False), (2, 0, 1))).float()
            rend_img_side = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=True), (2, 0, 1))).float()
            body_keypoints = pred_keypoints[i, :25]
            extra_keypoints = pred_keypoints[i, -19:]
            for pair in keypoint_matches:
                body_keypoints[pair[0], :] = extra_keypoints[pair[1], :]
            pred_keypoints_img = render_openpose(255 * images_np[i].copy(), body_keypoints) / 255
            body_keypoints = gt_keypoints[i, :25]
            extra_keypoints = gt_keypoints[i, -19:]
            for pair in keypoint_matches:
                if extra_keypoints[pair[1], -1] > 0 and body_keypoints[pair[0], -1] == 0:
                    body_keypoints[pair[0], :] = extra_keypoints[pair[1], :]
            gt_keypoints_img = render_openpose(255 * images_np[i].copy(), body_keypoints) / 255
            rend_imgs.append(torch.from_numpy(images[i]))
            rend_imgs.append(rend_img)
            rend_imgs.append(rend_img_side)
            rend_imgs.append(torch.from_numpy(pred_keypoints_img).permute(2, 0, 1))
            rend_imgs.append(torch.from_numpy(gt_keypoints_img).permute(2, 0, 1))
        rend_imgs = make_grid(rend_imgs, nrow=nrow, padding=padding)
        return rend_imgs

    def __call__(self, vertices, camera_translation, image, focal_length=5000, text=None, resize=None, side_view=False, baseColorFactor=(1.0, 1.0, 0.9, 1.0), rot_angle=90):
        renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1],
                                              viewport_height=image.shape[0],
                                              point_size=1.0)
        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.0,
            alphaMode='OPAQUE',
            baseColorFactor=baseColorFactor)

        camera_translation[0] *= -1.

        mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
        if side_view:
            rot = trimesh.transformations.rotation_matrix(
                np.radians(rot_angle), [0, 1, 0])
            mesh.apply_transform(rot)
        rot = trimesh.transformations.rotation_matrix(
            np.radians(180), [1, 0, 0])
        mesh.apply_transform(rot)
        mesh = pyrender.Mesh.from_trimesh(mesh, material=material)

        scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
                               ambient_light=(0.3, 0.3, 0.3))
        scene.add(mesh, 'mesh')

        camera_pose = np.eye(4)
        camera_pose[:3, 3] = camera_translation
        camera_center = [image.shape[1] / 2., image.shape[0] / 2.]
        camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
                                           cx=camera_center[0], cy=camera_center[1])
        scene.add(camera, pose=camera_pose)

        light_nodes = create_raymond_lights()
        for node in light_nodes:
            scene.add_node(node)

        color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
        color = color.astype(np.float32) / 255.0
        valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
        if not side_view:
            output_img = (color[:, :, :3] * valid_mask +
                          (1 - valid_mask) * image)
        else:
            output_img = color[:, :, :3]
        if resize is not None:
            output_img = cv2.resize(output_img, resize)

        output_img = output_img.astype(np.float32)
        renderer.delete()
        return output_img
fourm/utils/hmr2_utils/hmr2/utils/render_openpose.py
ADDED
@@ -0,0 +1,155 @@
# --------------------------------------------------------
# Based on the 4DHumans and ProHMR code bases
# https://github.com/shubham-goel/4D-Humans
# https://github.com/nkolot/ProHMR
# --------------------------------------------------------

"""
Render OpenPose keypoints.
Code was ported to Python from the official C++ implementation https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/utilities/keypoint.cpp
"""
import cv2
import math
import numpy as np
from typing import List, Tuple

def get_keypoints_rectangle(keypoints: np.array, threshold: float) -> Tuple[float, float, float]:
    """
    Compute rectangle enclosing keypoints above the threshold.
    Args:
        keypoints (np.array): Keypoint array of shape (N, 3).
        threshold (float): Confidence visualization threshold.
    Returns:
        Tuple[float, float, float]: Rectangle width, height and area.
    """
    valid_ind = keypoints[:, -1] > threshold
    if valid_ind.sum() > 0:
        valid_keypoints = keypoints[valid_ind][:, :-1]
        max_x = valid_keypoints[:, 0].max()
        max_y = valid_keypoints[:, 1].max()
        min_x = valid_keypoints[:, 0].min()
        min_y = valid_keypoints[:, 1].min()
        width = max_x - min_x
        height = max_y - min_y
        area = width * height
        return width, height, area
    else:
        return 0, 0, 0

def render_keypoints(img: np.array,
                     keypoints: np.array,
                     pairs: List,
                     colors: List,
                     thickness_circle_ratio: float,
                     thickness_line_ratio_wrt_circle: float,
                     pose_scales: List,
                     threshold: float = 0.1) -> np.array:
    """
    Render keypoints on input image.
    Args:
        img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
        keypoints (np.array): Keypoint array of shape (N, 3).
        pairs (List): List of keypoint pairs per limb.
        colors (List): List of colors per keypoint.
        thickness_circle_ratio (float): Circle thickness ratio.
        thickness_line_ratio_wrt_circle (float): Line thickness ratio wrt the circle.
        pose_scales (List): List of pose scales.
        threshold (float): Only visualize keypoints with confidence above the threshold.
    Returns:
        (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
    """
    img_orig = img.copy()
    width, height = img.shape[1], img.shape[2]
    area = width * height

    lineType = 8
    shift = 0
    numberColors = len(colors)
    thresholdRectangle = 0.1

    person_width, person_height, person_area = get_keypoints_rectangle(keypoints, thresholdRectangle)
    if person_area > 0:
        ratioAreas = min(1, max(person_width / width, person_height / height))
        thicknessRatio = np.maximum(np.round(math.sqrt(area) * thickness_circle_ratio * ratioAreas), 2)
        thicknessCircle = np.maximum(1, thicknessRatio if ratioAreas > 0.05 else -np.ones_like(thicknessRatio))
        thicknessLine = np.maximum(1, np.round(thicknessRatio * thickness_line_ratio_wrt_circle))
        radius = thicknessRatio / 2

        img = np.ascontiguousarray(img.copy())
        for i, pair in enumerate(pairs):
            index1, index2 = pair
            if keypoints[index1, -1] > threshold and keypoints[index2, -1] > threshold:
                thicknessLineScaled = int(round(min(thicknessLine[index1], thicknessLine[index2]) * pose_scales[0]))
                colorIndex = index2
                color = colors[colorIndex % numberColors]
                keypoint1 = keypoints[index1, :-1].astype(np.int)
                keypoint2 = keypoints[index2, :-1].astype(np.int)
                cv2.line(img, tuple(keypoint1.tolist()), tuple(keypoint2.tolist()), tuple(color.tolist()), thicknessLineScaled, lineType, shift)
        for part in range(len(keypoints)):
            faceIndex = part
            if keypoints[faceIndex, -1] > threshold:
                radiusScaled = int(round(radius[faceIndex] * pose_scales[0]))
                thicknessCircleScaled = int(round(thicknessCircle[faceIndex] * pose_scales[0]))
                colorIndex = part
                color = colors[colorIndex % numberColors]
                center = keypoints[faceIndex, :-1].astype(np.int)
                cv2.circle(img, tuple(center.tolist()), radiusScaled, tuple(color.tolist()), thicknessCircleScaled, lineType, shift)
    return img

def render_body_keypoints(img: np.array,
                          body_keypoints: np.array) -> np.array:
    """
    Render OpenPose body keypoints on input image.
    Args:
        img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
        body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
    Returns:
        (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
    """

    thickness_circle_ratio = 1./75. * np.ones(body_keypoints.shape[0])
    thickness_line_ratio_wrt_circle = 0.75
    pairs = []
    pairs = [1,8,1,2,1,5,2,3,3,4,5,6,6,7,8,9,9,10,10,11,8,12,12,13,13,14,1,0,0,15,15,17,0,16,16,18,14,19,19,20,14,21,11,22,22,23,11,24]
    pairs = np.array(pairs).reshape(-1, 2)
    colors = [255., 0., 85.,
              255., 0., 0.,
              255., 85., 0.,
              255., 170., 0.,
              255., 255., 0.,
              170., 255., 0.,
              85., 255., 0.,
              0., 255., 0.,
              255., 0., 0.,
              0., 255., 85.,
              0., 255., 170.,
              0., 255., 255.,
              0., 170., 255.,
              0., 85., 255.,
              0., 0., 255.,
              255., 0., 170.,
              170., 0., 255.,
              255., 0., 255.,
              85., 0., 255.,
              0., 0., 255.,
              0., 0., 255.,
              0., 0., 255.,
              0., 255., 255.,
              0., 255., 255.,
              0., 255., 255.]
    colors = np.array(colors).reshape(-1, 3)
    pose_scales = [1]
    return render_keypoints(img, body_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1)

def render_openpose(img: np.array,
                    body_keypoints: np.array) -> np.array:
    """
    Render keypoints in the OpenPose format on input image.
    Args:
        img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
        body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
    Returns:
        (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
    """
    img = render_body_keypoints(img, body_keypoints)
    return img
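A usage sketch for render_openpose with made-up keypoints given as (x, y, confidence) in pixel coordinates. Note that render_keypoints above uses the deprecated np.int alias, so this assumes a NumPy version where that alias still exists.

import numpy as np

img = np.zeros((256, 256, 3), dtype=np.float32)    # blank canvas; values are expected in [0, 255]
body_keypoints = np.zeros((44, 3))                 # 25 body + 19 extra keypoints; confidence 0 hides a joint
body_keypoints[0] = [120, 60, 0.9]                 # nose
body_keypoints[1] = [136, 90, 0.9]                 # neck
vis = render_openpose(img, body_keypoints)         # new (256, 256, 3) array with the nose-neck limb drawn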