aroraaman committed on
Commit 3424266 (1 parent: 453e23a)

Add all of `fourm`

This view is limited to 50 files because it contains too many changes.
Files changed (50):
  1. .gitattributes +1 -0
  2. fourm/data/__init__.py +4 -0
  3. fourm/data/dataset_utils.py +85 -0
  4. fourm/data/image_augmenter.py +186 -0
  5. fourm/data/masking.py +747 -0
  6. fourm/data/modality_info.py +427 -0
  7. fourm/data/modality_transforms.py +1387 -0
  8. fourm/data/multimodal_dataset_folder.py +363 -0
  9. fourm/data/pretrain_utils.py +292 -0
  10. fourm/data/transfer_utils.py +53 -0
  11. fourm/data/unified_datasets.py +557 -0
  12. fourm/demo_4M_sampler.py +540 -0
  13. fourm/models/__init__.py +0 -0
  14. fourm/models/decoder_embeddings.py +268 -0
  15. fourm/models/encoder_embeddings.py +422 -0
  16. fourm/models/fm.py +1130 -0
  17. fourm/models/fm_utils.py +387 -0
  18. fourm/models/fm_vit.py +485 -0
  19. fourm/models/generate.py +1273 -0
  20. fourm/models/lora_utils.py +177 -0
  21. fourm/utils/__init__.py +22 -0
  22. fourm/utils/checkpoint.py +185 -0
  23. fourm/utils/clip/__init__.py +2 -0
  24. fourm/utils/clip/clip.py +236 -0
  25. fourm/utils/clip/model.py +504 -0
  26. fourm/utils/clip/simple_tokenizer.py +137 -0
  27. fourm/utils/data_constants.py +40 -0
  28. fourm/utils/dist.py +100 -0
  29. fourm/utils/fsdp_utils.py +116 -0
  30. fourm/utils/generation.py +99 -0
  31. fourm/utils/generation_datasets/PartiPrompts.tsv +0 -0
  32. fourm/utils/generation_datasets/__init__.py +3 -0
  33. fourm/utils/generation_datasets/empty_dataset.py +27 -0
  34. fourm/utils/generation_datasets/image_caption_dataset.py +99 -0
  35. fourm/utils/generation_datasets/parti_prompts_dataset.py +114 -0
  36. fourm/utils/hmr2_utils/hmr2/__init__.py +0 -0
  37. fourm/utils/hmr2_utils/hmr2/models/__init__.py +2 -0
  38. fourm/utils/hmr2_utils/hmr2/models/backbones/__init__.py +13 -0
  39. fourm/utils/hmr2_utils/hmr2/models/backbones/vit.py +353 -0
  40. fourm/utils/hmr2_utils/hmr2/models/components/__init__.py +0 -0
  41. fourm/utils/hmr2_utils/hmr2/models/components/pose_transformer.py +363 -0
  42. fourm/utils/hmr2_utils/hmr2/models/components/t_cond_mlp.py +204 -0
  43. fourm/utils/hmr2_utils/hmr2/models/heads/__init__.py +1 -0
  44. fourm/utils/hmr2_utils/hmr2/models/heads/smpl_head.py +116 -0
  45. fourm/utils/hmr2_utils/hmr2/models/hmr2.py +117 -0
  46. fourm/utils/hmr2_utils/hmr2/models/smpl_wrapper.py +47 -0
  47. fourm/utils/hmr2_utils/hmr2/utils/__init__.py +31 -0
  48. fourm/utils/hmr2_utils/hmr2/utils/geometry.py +109 -0
  49. fourm/utils/hmr2_utils/hmr2/utils/mesh_renderer.py +155 -0
  50. fourm/utils/hmr2_utils/hmr2/utils/render_openpose.py +155 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ fourm/utils/clip/bpe_simple_vocab_16e6.txt.gz filter=lfs diff=lfs merge=lfs -text
fourm/data/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .image_augmenter import *
2
+ from .modality_transforms import *
3
+ from .unified_datasets import build_fm_pretraining_dataset, build_fm_transfer_dataset, build_wds_fm_pretraining_dataloader, build_wds_divae_dataloader, build_huggingface_pretraining_dataloader, build_mixture_dataloader
4
+ from .pretrain_utils import *
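The names re-exported here define the public surface of `fourm.data`. A minimal sketch of how downstream code would typically import from it (assuming the `fourm` package from this commit is installed and the star-imported modules do not restrict their exports via `__all__`):

    # Dataset/dataloader builders are re-exported by name from unified_datasets.
    from fourm.data import (
        build_fm_pretraining_dataset,
        build_fm_transfer_dataset,
        build_mixture_dataloader,
    )
    # Augmenters and modality transforms come in via the star imports,
    # e.g. RandomCropImageAugmenter from fourm.data.image_augmenter.
    from fourm.data import RandomCropImageAugmenter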
fourm/data/dataset_utils.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+ from torch.utils.data import Dataset
16
+
17
+
18
+ class RepeatedDatasetWrapper(Dataset):
19
+ def __init__(self, original_dataset, num_repeats):
20
+ """
21
+ Dataset wrapper that repeats the original dataset n times.
22
+
23
+ Args:
24
+ original_dataset (torch.utils.data.Dataset): The original dataset to be repeated.
25
+ num_repeats (int): The number of times the dataset should be repeated.
26
+ """
27
+ self.original_dataset = original_dataset
28
+ self.num_repeats = num_repeats
29
+
30
+ def __getitem__(self, index):
31
+ """
32
+ Retrieve the item at the given index.
33
+
34
+ Args:
35
+ index (int): The index of the item to be retrieved.
36
+ """
37
+ original_index = index % len(self.original_dataset)
38
+ return self.original_dataset[original_index]
39
+
40
+ def __len__(self):
41
+ """
42
+ Get the length of the dataset after repeating it n times.
43
+
44
+ Returns:
45
+ int: The length of the dataset.
46
+ """
47
+ return len(self.original_dataset) * self.num_repeats
48
+
49
+
50
+ class SubsampleDatasetWrapper(Dataset):
51
+ def __init__(self, original_dataset, dataset_size, seed=0, return_orig_idx=False):
52
+ """
53
+ Dataset wrapper that randomly subsamples the original dataset.
54
+
55
+ Args:
56
+ original_dataset (torch.utils.data.Dataset): The original dataset to be subsampled.
57
+ dataset_size (int): The size of the subsampled dataset.
58
+ seed (int): The seed to use for selecting the subset of indices of the original dataset.
59
+ return_orig_idx (bool): Whether to return the original index of the item in the original dataset.
60
+ """
61
+ self.original_dataset = original_dataset
62
+ self.dataset_size = dataset_size or len(original_dataset)
63
+ self.return_orig_idx = return_orig_idx
64
+ np.random.seed(seed)
65
+ self.indices = np.random.permutation(len(self.original_dataset))[:self.dataset_size]
66
+
67
+ def __getitem__(self, index):
68
+ """
69
+ Retrieve the item at the given index.
70
+
71
+ Args:
72
+ index (int): The index of the item to be retrieved.
73
+ """
74
+ original_index = self.indices[index]
75
+ sample = self.original_dataset[original_index]
76
+ return (sample, original_index) if self.return_orig_idx else sample
77
+
78
+ def __len__(self):
79
+ """
80
+ Get the length of the dataset after subsampling it.
81
+
82
+ Returns:
83
+ int: The length of the dataset.
84
+ """
85
+ return len(self.indices)
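A minimal usage sketch of the two wrappers above, with a toy in-memory dataset (the `ToyDataset` class is illustrative and not part of this commit):

    from torch.utils.data import Dataset
    from fourm.data.dataset_utils import RepeatedDatasetWrapper, SubsampleDatasetWrapper

    class ToyDataset(Dataset):
        def __init__(self, n=10):
            self.items = list(range(n))
        def __getitem__(self, idx):
            return self.items[idx]
        def __len__(self):
            return len(self.items)

    base = ToyDataset(n=10)

    # Repeat the dataset 3 times; indices wrap around via modulo.
    repeated = RepeatedDatasetWrapper(base, num_repeats=3)
    assert len(repeated) == 30 and repeated[25] == base[25 % 10]

    # Reproducibly subsample 4 items; with return_orig_idx=True each item
    # comes back as a (sample, original_index) pair.
    subsampled = SubsampleDatasetWrapper(base, dataset_size=4, seed=0, return_orig_idx=True)
    sample, orig_idx = subsampled[0]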
fourm/data/image_augmenter.py ADDED
@@ -0,0 +1,186 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import random
15
+ from abc import ABC, abstractmethod
16
+
17
+ import numpy as np
18
+ import torchvision
19
+
20
+ from fourm.utils import to_2tuple
21
+
22
+
23
+ class AbstractImageAugmenter(ABC):
24
+ """Abstract class for image augmenters.
25
+ """
26
+
27
+ @abstractmethod
28
+ def __call__(self, mod_dict, crop_settings):
29
+ pass
30
+
31
+
32
+ class RandomCropImageAugmenter(AbstractImageAugmenter):
33
+
34
+ def __init__(self, target_size=224, hflip=0.5, crop_scale=(0.2, 1.0), crop_ratio=(0.75, 1.3333), main_domain='rgb'):
35
+
36
+ self.target_size = to_2tuple(target_size)
37
+ self.hflip = hflip
38
+ self.crop_scale = crop_scale
39
+ self.crop_ratio = crop_ratio
40
+ self.main_domain = main_domain
41
+
42
+ def __call__(self, mod_dict, crop_settings):
43
+
44
+ if crop_settings is not None:
45
+ raise ValueError("Crop settings are provided but not used by this augmenter.")
46
+
47
+ image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
48
+ # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
49
+ orig_width, orig_height = image.size
50
+ orig_size = (orig_height, orig_width)
51
+
52
+ top, left, h, w = torchvision.transforms.RandomResizedCrop.get_params(
53
+ image, scale=self.crop_scale, ratio=self.crop_ratio
54
+ )
55
+ crop_coords = top, left, h, w
56
+ flip = random.random() < self.hflip
57
+ rand_aug_idx = None
58
+
59
+ return crop_coords, flip, orig_size, self.target_size, rand_aug_idx
60
+
61
+ class NoImageAugmenter(AbstractImageAugmenter):  # For non-image modalities (e.g. human poses) where no augmentations are applied, e.g. during tokenization
62
+
63
+ def __init__(self, no_aug=True, main_domain='human_poses'):
64
+ self.target_size = None #to_2tuple(target_size)
65
+ self.no_aug = no_aug
66
+ self.main_domain = main_domain
67
+
68
+ def __call__(self, mod_dict, crop_settings):
69
+ # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
70
+ orig_size = (224, 224)
71
+
72
+ rand_aug_idx = 0
73
+ top, left, h, w, flip = 0, 0, 224, 224, 0
74
+ crop_coords = (top, left, h, w)
75
+
76
+ return crop_coords, flip, orig_size, self.target_size, rand_aug_idx
77
+
78
+ class PreTokenizedImageAugmenter(AbstractImageAugmenter):
79
+
80
+ def __init__(self, target_size, no_aug=False, main_domain='rgb'):
81
+ self.target_size = to_2tuple(target_size)
82
+ self.no_aug = no_aug
83
+ self.main_domain = main_domain
84
+
85
+ def __call__(self, mod_dict, crop_settings):
86
+ # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
87
+ if self.main_domain in mod_dict and 'tok' not in self.main_domain:
88
+ image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
89
+ orig_width, orig_height = image.size
90
+ orig_size = (orig_height, orig_width)
91
+ else:
92
+ orig_size = None
93
+
94
+ rand_aug_idx = 0 if self.no_aug else np.random.randint(len(crop_settings))
95
+ top, left, h, w, flip = crop_settings[rand_aug_idx]
96
+ crop_coords = (top, left, h, w)
97
+
98
+ return crop_coords, flip, orig_size, self.target_size, rand_aug_idx
99
+
100
+
101
+ class CenterCropImageAugmenter(AbstractImageAugmenter):
102
+ def __init__(self, target_size, hflip=0.0, main_domain='rgb'):
103
+ self.target_size = to_2tuple(target_size)
104
+ self.hflip = hflip
105
+ self.main_domain = main_domain
106
+
107
+ def __call__(self, mod_dict, crop_settings=None):
108
+ image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
109
+ orig_width, orig_height = image.size
110
+ orig_size = (orig_height, orig_width)
111
+
112
+ if orig_height > orig_width:
113
+ h = w = orig_width
114
+ top = (orig_height - orig_width) // 2
115
+ left = 0
116
+ else:
117
+ h = w = orig_height
118
+ top = 0
119
+ left = (orig_width - orig_height) // 2
120
+
121
+ crop_coords = (top, left, h, w)
122
+ flip = random.random() < self.hflip
123
+ rand_aug_idx = None
124
+
125
+ return crop_coords, flip, orig_size, self.target_size, rand_aug_idx
126
+
127
+
128
+ class PaddingImageAugmenter(AbstractImageAugmenter):
129
+ def __init__(self, target_size, hflip=0.0, main_domain='rgb'):
130
+ self.target_size = to_2tuple(target_size)
131
+ self.hflip = hflip
132
+ self.main_domain = main_domain
133
+
134
+ def __call__(self, mod_dict, crop_settings):
135
+ image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
136
+ orig_width, orig_height = image.size
137
+ orig_size = (orig_height, orig_width)
138
+
139
+ h = w = max(orig_width, orig_height)
140
+ top = left = 0
141
+ crop_coords = (top, left, h, w)
142
+ flip = random.random() < self.hflip
143
+ rand_aug_idx = None
144
+
145
+ return crop_coords, flip, orig_size, self.target_size, rand_aug_idx
146
+
147
+
148
+ class ScaleJitteringImageAugmenter(AbstractImageAugmenter):
149
+ def __init__(self, target_size, hflip=0.0, scale=(0.1, 2.0), main_domain='rgb'):
150
+ self.target_size = to_2tuple(target_size)
151
+ self.hflip = hflip
152
+ self.scale = scale
153
+ self.main_domain = main_domain
154
+
155
+ def scale_jitter(self, orig_height, orig_width):
156
+ rand_scale = np.random.uniform(self.scale[0], self.scale[1])
157
+ max_hw = max(orig_height, orig_width)
158
+ h = w = round(max_hw / rand_scale)
159
+ top = round(max(0, np.random.uniform(0, orig_height - h)))
160
+ left = round(max(0, np.random.uniform(0, orig_width - w)))
161
+
162
+ return top, left, h, w
163
+
164
+ def __call__(self, mod_dict, crop_settings):
165
+
166
+ if crop_settings is not None:
167
+ raise ValueError("Crop settings are provided but not used by this augmenter.")
168
+
169
+ image = mod_dict[self.main_domain] if self.main_domain is not None else mod_dict[list(mod_dict.keys())[0]]
170
+ # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
171
+ orig_width, orig_height = image.size
172
+ orig_size = (orig_height, orig_width)
173
+
174
+ crop_coords = self.scale_jitter(orig_height, orig_width)
175
+ flip = random.random() < self.hflip
176
+ rand_aug_idx = None
177
+
178
+ return crop_coords, flip, orig_size, self.target_size, rand_aug_idx
179
+
180
+
181
+ class EmptyAugmenter(AbstractImageAugmenter):
182
+ def __init__(self):
183
+ pass
184
+
185
+ def __call__(self, mod_dict, crop_settings):
186
+ return None, None, None, None, None
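A brief sketch of how these augmenters are driven, using a synthetic PIL image in place of a real sample (assumes the `fourm` package from this commit and its torchvision dependency are installed):

    from PIL import Image
    from fourm.data.image_augmenter import CenterCropImageAugmenter, RandomCropImageAugmenter

    # A dummy 640x480 RGB image, keyed by its modality name.
    mod_dict = {'rgb': Image.new('RGB', (640, 480))}

    center = CenterCropImageAugmenter(target_size=224)
    crop_coords, flip, orig_size, target_size, rand_aug_idx = center(mod_dict)
    # crop_coords = (0, 80, 480, 480): a centered 480x480 square,
    # orig_size = (480, 640), target_size = (224, 224), rand_aug_idx = None.

    # The random-crop variant samples a RandomResizedCrop-style window and
    # expects crop_settings to be None.
    random_crop = RandomCropImageAugmenter(target_size=224, hflip=0.5)
    crop_coords, flip, orig_size, target_size, _ = random_crop(mod_dict, crop_settings=None)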
fourm/data/masking.py ADDED
@@ -0,0 +1,747 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ import random
16
+ from typing import Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from einops import rearrange
22
+ from tokenizers import Tokenizer
23
+ from torch.distributions import Dirichlet
24
+
25
+ from fourm.data.modality_transforms import get_transform_key
26
+ from fourm.utils import to_2tuple
27
+ from fourm.utils.tokenizer import get_sentinel_to_id_mapping
28
+
29
+
30
+ def sample_cosine(min_val: float = 0, max_val: float =1) -> float:
31
+ """Sample a value from a cosine distribution between min_val and max_val
32
+
33
+ Args:
34
+ min_val: Minimum value
35
+ max_val: Maximum value
36
+
37
+ Returns:
38
+ Sampled value
39
+ """
40
+
41
+ return min_val + 0.5 * (max_val - min_val) * (1 + math.cos(math.pi * random.uniform(0, 1)))
42
+
43
+
44
+ def sample_uniform(min_val: float = 0, max_val: float =1) -> float:
45
+ """Sample a value from a uniform distribution between min_val and max_val
46
+
47
+ Args:
48
+ min_val: Minimum value
49
+ max_val: Maximum value
50
+
51
+ Returns:
52
+ Sampled value
53
+ """
54
+
55
+ return random.uniform(min_val, max_val)
56
+
57
+
58
+ def simple_span_masking(sequence: List[int], sentinel_to_id: Dict[int, int], keep_prob: float) -> Tuple[List[int], List[int]]:
59
+ """Span masking for a sequence
60
+
61
+ Args:
62
+ sequence: Sequence to mask
63
+ sentinel_to_id: Mapping from sentinel to id
64
+ keep_prob: Probability of keeping a token
65
+
66
+ Returns:
67
+ Masked input sequence and masked target sequence
68
+ """
69
+ sequence_length = len(sequence)
70
+ # 0 for keep, 1 for mask
71
+ masks = torch.where(torch.rand(sequence_length) <= keep_prob, 0, 1).bool().tolist()
72
+
73
+ input_sequence = []
74
+ target_sequence = []
75
+
76
+ prev_mask = False
77
+ sentinel_count = 0
78
+ for token, mask in zip(sequence, masks):
79
+ if mask:
80
+ if not prev_mask:
81
+ sentinel_count += 1
82
+ input_sequence.append(sentinel_to_id[sentinel_count])
83
+ target_sequence.append(sentinel_to_id[sentinel_count])
84
+ prev_mask = True
85
+ target_sequence.append(token)
86
+ else:
87
+ prev_mask = False
88
+ input_sequence.append(token)
89
+
90
+ target_sequence.append(sentinel_to_id[sentinel_count + 1])
91
+ return input_sequence, target_sequence
92
+
93
+
94
+ def chunk_span_masking(sequence_chunks: List[List[int]], sentinel_to_id: Dict[int, int], keep_prob: float) -> Tuple[List[int], List[int]]:
95
+ """Span masking where masking is performed at the chunk level.
96
+
97
+ Args:
98
+ sequence_chunks: Sequence chunks to mask
99
+ sentinel_to_id: Mapping from sentinel to id
100
+ keep_prob: Probability of keeping a token
101
+
102
+ Returns:
103
+ Masked input sequence and masked target sequence
104
+ """
105
+ chunk_length = len(sequence_chunks)
106
+ # 0 for keep, 1 for mask
107
+ masks = torch.where(torch.rand(chunk_length) <= keep_prob, 0, 1).bool().tolist()
108
+
109
+ input_sequence = []
110
+ target_sequence = []
111
+
112
+ prev_mask = False
113
+ sentinel_count = 0
114
+ for chunk, mask in zip(sequence_chunks, masks):
115
+ if mask:
116
+ if not prev_mask:
117
+ sentinel_count += 1
118
+ input_sequence.append(sentinel_to_id[sentinel_count])
119
+ target_sequence.append(sentinel_to_id[sentinel_count])
120
+ prev_mask = True
121
+ target_sequence.extend(chunk)
122
+ else:
123
+ prev_mask = False
124
+ input_sequence.extend(chunk)
125
+
126
+ target_sequence.append(sentinel_to_id[sentinel_count + 1])
127
+ return input_sequence, target_sequence
128
+
129
+
130
+
131
+ class UnifiedMasking(object):
132
+ def __init__(self,
133
+ modality_info: Dict,
134
+ text_tokenizer: Optional[Tokenizer],
135
+ input_tokens_range: Union[int, Tuple[int, int]],
136
+ target_tokens_range: Optional[Union[int, Tuple[int, int]]],
137
+ max_tries: int = 100,
138
+ sampling_weights: Optional[List[float]] = None,):
139
+ """Performs masking on a dict of modalities (both image based and sequence based modalities)
140
+
141
+ Args:
142
+ modality_info: Dict with the modalities and their corresponding information
143
+ text_tokenizer: Tokenizer to use for text modalities
144
+ input_tokens_range: Range of number of tokens to mask in the input
145
+ target_tokens_range: Range of number of tokens to mask in the target
146
+ max_tries: Maximum number of tries to find a valid token budget
147
+ sampling_weights: Sampling weights for the mixture of Dirichlet distributions
148
+ """
149
+ self.input_tokens_range = to_2tuple(input_tokens_range)
150
+ self.target_tokens_range = to_2tuple(target_tokens_range) if target_tokens_range is not None else None
151
+ self.modality_info = modality_info
152
+ self.num_modalities = len(modality_info)
153
+ self.max_tries = max_tries
154
+ self.min_tokens = torch.tensor([mod['min_tokens'] for mod in modality_info.values()])
155
+ self.max_tokens = torch.tensor([mod['max_tokens'] for mod in modality_info.values()])
156
+ self.mod_is_img = torch.tensor([mod['type'] == 'img' for mod in modality_info.values()])
157
+
158
+ # Dirichlet sampling (supports a mixture of multiple Dirichlet distributions)
159
+ eps = 1e-9
160
+ input_alphas = torch.tensor([mod["input_alphas"] for mod in modality_info.values()])
161
+ input_alphas = rearrange(input_alphas, "nmod nmix -> nmix nmod")
162
+ self.input_dirichlets = [Dirichlet(torch.clamp(input_alpha, min=eps)) for input_alpha in input_alphas]
163
+ target_alphas = torch.tensor([mod["target_alphas"] for mod in modality_info.values()])
164
+ target_alphas = rearrange(target_alphas, "nmod nmix -> nmix nmod")
165
+ self.target_dirichlets = [Dirichlet(torch.clamp(target_alpha, min=eps)) for target_alpha in target_alphas]
166
+ assert(len(self.input_dirichlets) == len(self.target_dirichlets))
167
+ self.num_dirichlets = len(self.input_dirichlets)
168
+ if sampling_weights is not None:
169
+ assert len(sampling_weights) == self.num_dirichlets
170
+ self.sampling_weights = torch.tensor(sampling_weights)
171
+ else:
172
+ self.sampling_weights = None
173
+
174
+ self.text_tokenizer = text_tokenizer
175
+ self.keep_prob_decay_factor = 0.9
176
+ self.sentinel_to_id = get_sentinel_to_id_mapping(text_tokenizer)
177
+ self.sentinel_ids = set(self.sentinel_to_id.values())
178
+ self.pad_id = text_tokenizer.token_to_id("[PAD]")
179
+ self.eos_id = text_tokenizer.token_to_id("[EOS]")
180
+
181
+ def input_token_budget(self, num_input_tokens, dir_idx=0):
182
+ """Sample a token budget for the input
183
+
184
+ Args:
185
+ num_input_tokens: Number of tokens in the input
186
+
187
+ Returns:
188
+ Token budget for the input
189
+ """
190
+ # Get the number of tokens for each modality
191
+ for i in range(self.max_tries):
192
+ input_token_budget = (self.input_dirichlets[dir_idx].sample() * num_input_tokens).floor().int()
193
+ diff = num_input_tokens - input_token_budget.sum()
194
+ # Adds the remaining tokens by sampling from the Dirichlet and taking the argmax
195
+ # This avoids adding tokens to modalities that shouldn't be sampled (i.e. with alphas ~=0)
196
+ input_token_budget += torch.bincount(self.input_dirichlets[dir_idx].sample_n(diff).argmax(dim=-1), minlength=len(input_token_budget))
197
+
198
+ # If token budget is over max tokens for a given modality, set it to max
199
+ input_token_budget = torch.clamp(input_token_budget, max=self.max_tokens)
200
+
201
+ if (input_token_budget >= self.min_tokens).all():
202
+ return input_token_budget.tolist()
203
+
204
+ print("Exceeded max tries when sampling the input token budget!")
205
+ return input_token_budget.tolist()
206
+
207
+ def target_token_budget(self, input_token_budget, num_target_tokens, dir_idx=0):
208
+ """Sample a token budget for the target
209
+
210
+ Args:
211
+ input_token_budget: Token budget for the input
212
+ num_target_tokens: Number of tokens in the target
213
+
214
+ Returns:
215
+ Token budget for the target
216
+ """
217
+ # We don't reduce the number of tokens for sequence based tasks
218
+ max_tokens_remaining = torch.where(self.mod_is_img, self.max_tokens - torch.tensor(input_token_budget), self.max_tokens)
219
+ max_tokens_remaining = torch.max(self.min_tokens, max_tokens_remaining)
220
+ for i in range(self.max_tries):
221
+ target_token_budget = (self.target_dirichlets[dir_idx].sample() * num_target_tokens).floor().int()
222
+ diff = num_target_tokens - target_token_budget.sum()
223
+ # Adds the remaining tokens by sampling from the Dirichlet and taking the argmax
224
+ # This avoids adding tokens to modalities that shouldn't be sampled (i.e. with alphas ~=0)
225
+ target_token_budget += torch.bincount(self.target_dirichlets[dir_idx].sample_n(diff).argmax(dim=-1), minlength=len(target_token_budget))
226
+
227
+ # If token budget is over max tokens for a given modality, set it to max
228
+ target_token_budget = torch.clamp(target_token_budget, max=max_tokens_remaining)
229
+
230
+ if (target_token_budget >= self.min_tokens).all():
231
+ return target_token_budget.tolist()
232
+
233
+ print("Exceeded max tries when sampling the target token budget!")
234
+ return target_token_budget.tolist()
235
+
236
+ def image_mask(self, tensor: torch.Tensor, num_tokens: int, input_budget: int, target_budget: int):
237
+ """Applies input and target masking to an image tensor
238
+
239
+ Args:
240
+ tensor: Image tensor
241
+ num_tokens: Number of tokens in the tensor
242
+ input_budget: Token budget for the input
243
+ target_budget: Token budget for the target
244
+
245
+ Returns:
246
+ Dictionary containing the masked image tensor, the input mask, the target mask, and the decoder attention mask
247
+ """
248
+ noise = torch.rand(num_tokens)
249
+ ids_shuffle = torch.argsort(noise, dim=0)
250
+
251
+ input_mask = torch.ones(num_tokens, dtype=torch.bool)
252
+ input_mask[:input_budget] = 0
253
+ input_mask = torch.gather(input_mask, dim=0, index=ids_shuffle)
254
+
255
+ if target_budget is None:
256
+ target_mask = ~input_mask
257
+ else:
258
+ target_mask = torch.ones(num_tokens, dtype=torch.bool)
259
+ target_mask[input_budget:input_budget + target_budget] = 0
260
+ target_mask = torch.gather(target_mask, dim=0, index=ids_shuffle)
261
+
262
+ decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)
263
+ first_mask_token = torch.argmin(target_mask + torch.arange(target_mask.shape[0], device=target_mask.device) * 1e-6)
264
+ decoder_attention_mask[first_mask_token] = (~target_mask).sum() # Equiv. to target budget
265
+
266
+ return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
267
+
268
+ def sequence_token_mask(self, sequence_ids: str, max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str, vocab_offset: int):
269
+ """Applies input and target masking to a sequence of tokens (e.g. DINOv2 global tokens)
270
+ The keep probability is sampled according to the chosen keep scheme and does not depend on the number of tokens in the sequence.
271
+ If the keep probability results in a sequence that is too long, then it is lowered until the sequence is short enough.
272
+
273
+ Args:
274
+ sequence_ids: Sequence ids
275
+ max_tokens: Maximum number of tokens in the sequence
276
+ input_budget: Token budget for the input
277
+ target_budget: Token budget for the target
278
+ keep_scheme: Scheme for sampling the keep probability
279
+ vocab_offset: Offset to avoid overlap with sentinel tokens
280
+
281
+ Returns:
282
+ Dictionary containing the masked sequence tensor, the input mask, the target mask, and the decoder attention mask
283
+ """
284
+ seq_ids = sequence_ids
285
+ seq_ids = seq_ids + vocab_offset # Avoid overlap with sentinel tokens (needs to be subtracted after decoding)
286
+
287
+ # If input budget is 0, treat it as if the whole sequence is completely masked
288
+ if input_budget == 0:
289
+ keep_prob = 0.
290
+ input_seq_ids = []
291
+ _, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
292
+ else:
293
+ if keep_scheme == 'random':
294
+ keep_prob = sample_uniform(0, 1)
295
+ elif keep_scheme == 'all':
296
+ keep_prob = 1.0
297
+ elif keep_scheme == 'binary':
298
+ keep_prob = random.choice([0., 1.])
299
+ else:
300
+ raise ValueError(f"Invalid keep scheme for sequence masking: {keep_scheme}")
301
+
302
+ input_seq_ids, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
303
+ # Keep lowering the keep_prob while we are over-budget
304
+ while len(input_seq_ids) > input_budget:
305
+ keep_prob = keep_prob * self.keep_prob_decay_factor
306
+ input_seq_ids, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
307
+
308
+ # Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
309
+ max_length = (max_tokens + 1) * 2
310
+ tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
311
+ input_mask = torch.ones(max_length, dtype=torch.bool)
312
+ target_mask = torch.ones(max_length, dtype=torch.bool)
313
+ decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
314
+
315
+ # Set input and input mask
316
+ tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
317
+ input_mask[:len(input_seq_ids)] = 0
318
+
319
+ if target_budget is None or len(target_seq_ids) <= target_budget:
320
+ tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
321
+ target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
322
+ decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1
323
+ else:
324
+ # Randomly choose sentinel token.
325
+ sentinel_indices = [i for i, token_id in enumerate(target_seq_ids) if token_id in self.sentinel_ids]
326
+ # If there is more than 1 sentinel, avoid sampling the very last one which indicates the end of the sequence
327
+ chosen_sentinel = np.random.randint(max(1, len(sentinel_indices) - 1))
328
+ # If the length starting at this sentinel is greater than the budget, truncate until the budget is reached
329
+ if len(target_seq_ids) - sentinel_indices[chosen_sentinel] >= target_budget:
330
+ target_seq_ids = target_seq_ids[sentinel_indices[chosen_sentinel]:sentinel_indices[chosen_sentinel] + target_budget]
331
+ # Otherwise, select earliest sentinel token such that we don't go over budget
332
+ # Note: We could also use the randomly chosen sentinel token, but that would waste budget
333
+ else:
334
+ for idx in sentinel_indices:
335
+ if len(target_seq_ids) - idx <= target_budget:
336
+ target_seq_ids = target_seq_ids[idx:]
337
+ break
338
+
339
+ tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
340
+ target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
341
+ decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1
342
+
343
+ return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
344
+
345
+ def sequence_mask(self, sequence: Union[str, List[str]], max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str):
346
+ """Applies input and target masking to a sequence
347
+
348
+ The keep probability is sampled according to the chosen keep scheme and does not depend on the number of tokens in the sequence.
349
+ If the keep probability results in a sequence that is too long, then it is lowered until the sequence is short enough.
350
+
351
+ Args:
352
+ sequence: Sequence, can be either a str or list of strings
353
+ max_tokens: Maximum number of tokens in the sequence
354
+ input_budget: Token budget for the input
355
+ target_budget: Token budget for the target
356
+ keep_scheme: Scheme for sampling the keep probability
357
+
358
+ Returns:
359
+ Dictionary containing the masked sequence tensor, the input mask, the target mask, and the decoder attention mask
360
+ """
361
+ if isinstance(sequence, str):
362
+ # Tokenize the sequence and get the ids
363
+ seq_ids: List[int] = self.text_tokenizer.encode(sequence).ids
364
+ # Add EOS to all sequences
365
+ seq_ids.append(self.eos_id)
366
+ # Truncate sequence
367
+ seq_ids = seq_ids[:max_tokens]
368
+
369
+ # Use default span masking
370
+ span_masking_fn = simple_span_masking
371
+
372
+ elif isinstance(sequence, list):
373
+ # Tokenize the sequence chunks and get the ids
374
+ encoded_seq_chunks = self.text_tokenizer.encode_batch(sequence)
375
+ seq_ids: List[List[int]] = [seq.ids for seq in encoded_seq_chunks]
376
+ # Add EOS as an extra chunk
377
+ seq_ids.append([self.eos_id])
378
+ # Truncate sequence to keep all chunks below max token length
379
+ cumulative_token_count = np.cumsum(np.array([len(chunk) for chunk in seq_ids]))
380
+ seq_ids = [chunk for (chunk, token_count) in zip(seq_ids, cumulative_token_count) if token_count <= max_tokens]
381
+
382
+ # Span mask over chunks
383
+ span_masking_fn = chunk_span_masking
384
+
385
+ else:
386
+ raise ValueError(f"Invalid sequence: {sequence}")
387
+
388
+
389
+ # If input budget is 0, treat it as if the whole sequence is completely masked
390
+ if input_budget == 0:
391
+ keep_prob = 0.
392
+ input_seq_ids = []
393
+ _, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, keep_prob)
394
+ else:
395
+ if keep_scheme == 'random':
396
+ keep_prob = sample_uniform(0, 1)
397
+ elif keep_scheme == 'all':
398
+ keep_prob = 1.0
399
+ elif keep_scheme == 'binary':
400
+ keep_prob = random.choice([0., 1.])
401
+ else:
402
+ raise ValueError(f"Invalid keep scheme for sequence masking: {keep_scheme}")
403
+
404
+ input_seq_ids, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, keep_prob)
405
+ # Keep lowering the keep_prob while we are over-budget
406
+ while len(input_seq_ids) > input_budget:
407
+ keep_prob = keep_prob * self.keep_prob_decay_factor
408
+ input_seq_ids, target_seq_ids = span_masking_fn(seq_ids, self.sentinel_to_id, keep_prob)
409
+
410
+ # Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
411
+ max_length = (max_tokens + 1) * 2
412
+ tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
413
+ input_mask = torch.ones(max_length, dtype=torch.bool)
414
+ target_mask = torch.ones(max_length, dtype=torch.bool)
415
+ decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
416
+
417
+ # Set input and input mask
418
+ tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
419
+ input_mask[:len(input_seq_ids)] = 0
420
+
421
+ if target_budget is None or len(target_seq_ids) <= target_budget:
422
+ tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
423
+ target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
424
+ decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1
425
+ else:
426
+ # Randomly choose sentinel token.
427
+ sentinel_indices = [i for i, token_id in enumerate(target_seq_ids) if token_id in self.sentinel_ids]
428
+ # If there is more than 1 sentinel, avoid sampling the very last one which indicates the end of the sequence
429
+ chosen_sentinel = np.random.randint(max(1, len(sentinel_indices) - 1))
430
+ # If the length starting at this sentinel is greater than the budget, truncate until the budget is reached
431
+ if len(target_seq_ids) - sentinel_indices[chosen_sentinel] >= target_budget:
432
+ target_seq_ids = target_seq_ids[sentinel_indices[chosen_sentinel]:sentinel_indices[chosen_sentinel] + target_budget]
433
+ # Otherwise, select earliest sentinel token such that we don't go over budget
434
+ # Note: We could also use the randomly chosen sentinel token, but that would waste budget
435
+ else:
436
+ for idx in sentinel_indices:
437
+ if len(target_seq_ids) - idx <= target_budget:
438
+ target_seq_ids = target_seq_ids[idx:]
439
+ break
440
+
441
+ tensor[input_budget:input_budget + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
442
+ target_mask[input_budget:input_budget + len(target_seq_ids)] = 0
443
+ decoder_attention_mask[input_budget:input_budget + len(target_seq_ids)] = 1
444
+
445
+ return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
446
+
447
+
448
+ def sequence_emb_mask_span(self, emb_tensor: torch.Tensor, max_tokens: int, input_budget: int, target_budget: int, keep_scheme: str):
449
+ """Applies input masking to a sequence embedding tensor; target masking is not supported with sequence embeddings
450
+
451
+ Args:
452
+ emb_tensor: Sequence embedding tensor
453
+ max_tokens: Maximum number of tokens in the sequence
454
+ input_budget: Token budget for the input
455
+ target_budget: Token budget for the target (unused for now)
456
+ keep_scheme: Scheme for sampling the keep probability
457
+
458
+ Returns:
459
+ Dictionary containing the masked sequence embedding tensor, the input mask, the target mask, and the decoder attention mask
460
+ """
461
+ # Only supported as input modality now
462
+
463
+ # Make fake seq ids for sequence embeddings to reuse simple_span_masking function
464
+ fake_seq_ids = []
465
+ emb_dict = {}
466
+ id_num = len(self.sentinel_ids)
467
+ emb_ind = 0
468
+ while(len(fake_seq_ids) < len(emb_tensor)):
469
+ if id_num not in self.sentinel_ids: # replace with T5 sentinel_id
470
+ fake_seq_ids.append(id_num)
471
+ emb_dict[id_num] = emb_tensor[emb_ind, :]
472
+ emb_ind += 1
473
+ id_num += 1
474
+
475
+ # Truncate sequence
476
+ fake_seq_ids = fake_seq_ids[:max_tokens]
477
+
478
+ # If input budget is 0, treat it as if the whole sequence is completely masked
479
+ if input_budget == 0:
480
+ keep_prob = 0.
481
+ fake_input_seq_ids = []
482
+ _, fake_target_seq_ids = simple_span_masking(fake_seq_ids, self.sentinel_to_id, keep_prob)
483
+ else:
484
+ if keep_scheme == 'random':
485
+ keep_prob = sample_uniform(0, 1)
486
+ elif keep_scheme == 'all':
487
+ keep_prob = 1.0
488
+ elif keep_scheme == 'binary':
489
+ keep_prob = random.choice([0., 1.])
490
+ else:
491
+ raise ValueError(f"Invalid keep scheme for sequence masking: {keep_scheme}")
492
+
493
+ fake_input_seq_ids, fake_target_seq_ids = simple_span_masking(fake_seq_ids, self.sentinel_to_id, keep_prob)
494
+ # Keep lowering the keep_prob while we are over-budget
495
+ while len(fake_input_seq_ids) > input_budget:
496
+ keep_prob = keep_prob * self.keep_prob_decay_factor
497
+ fake_input_seq_ids, fake_target_seq_ids = simple_span_masking(fake_seq_ids, self.sentinel_to_id, keep_prob)
498
+
499
+ # Span masking can add up to max_tokens tokens for input
500
+ max_length = max_tokens
501
+ tensor = torch.zeros((max_length, emb_tensor.shape[1]), dtype=torch.float32)
502
+ input_mask = torch.ones(max_length, dtype=torch.bool)
503
+ target_mask = torch.ones(max_length, dtype=torch.bool)
504
+ decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
505
+
506
+ # Put tensor values back based on the fake seq ids
507
+ for i_, fake_id in enumerate(fake_input_seq_ids):
508
+ if fake_id in self.sentinel_ids:
509
+ tensor[i_, :] = torch.zeros_like(emb_tensor[0,:]) # TODO replace to learned embeddings later
510
+ else:
511
+ tensor[i_, :] = emb_dict[fake_id]
512
+
513
+ # Set input and input mask
514
+ input_mask[:len(fake_input_seq_ids)] = 0
515
+
516
+ return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
517
+
518
+
519
+ def __call__(self, mod_dict):
520
+ """Applies input and target masking to a dictionary of modalities
521
+
522
+ Args:
523
+ mod_dict: Dictionary of modalities
524
+
525
+ Returns:
526
+ Dictionary containing the masked modalities
527
+ """
528
+ if self.sampling_weights is not None:
529
+ # Sample masking scheme according to a list of weights
530
+ dir_idx = torch.multinomial(self.sampling_weights, 1).item()
531
+ else:
532
+ # Randomly sample masking scheme
533
+ dir_idx = random.randint(0, self.num_dirichlets - 1)
534
+
535
+ num_input_tokens = random.randint(*self.input_tokens_range)
536
+ num_target_tokens = random.randint(*self.target_tokens_range) if self.target_tokens_range is not None else None
537
+
538
+ input_token_budget = self.input_token_budget(num_input_tokens, dir_idx)
539
+
540
+ if num_target_tokens is not None:
541
+ target_token_budget = self.target_token_budget(input_token_budget, num_target_tokens, dir_idx)
542
+ else:
543
+ target_token_budget = [None] * self.num_modalities
544
+
545
+ masked_mod_dict = {}
546
+ for (mod_name, mod_info), input_budget, target_budget in zip(self.modality_info.items(), input_token_budget, target_token_budget):
547
+ mod_type = mod_info['type']
548
+ mod_name_load = mod_name if mod_name in mod_dict else get_transform_key(mod_name)
549
+ if mod_type == 'img':
550
+ masked_mod_dict[mod_name] = self.image_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget)
551
+ elif mod_type == 'seq':
552
+ keep_scheme = 'random' if ('keep' not in mod_info) else mod_info['keep'][dir_idx]
553
+ masked_mod_dict[mod_name] = self.sequence_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme)
554
+ elif mod_type == 'seq_token':
555
+ keep_scheme = 'random' if ('keep' not in mod_info) else mod_info['keep'][dir_idx]
556
+ vocab_offset = mod_info.get('vocab_offset', 0) # Check if any space is allocated to sentinel tokens and other special tokens
557
+ masked_mod_dict[mod_name] = self.sequence_token_mask(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme, vocab_offset=vocab_offset)
558
+ elif mod_type == "seq_emb":
559
+ keep_scheme = 'random' if ('keep' not in mod_info) else mod_info['keep'][dir_idx]
560
+ masked_mod_dict[mod_name] = self.sequence_emb_mask_span(mod_dict[mod_name_load], mod_info['max_tokens'], input_budget, target_budget, keep_scheme)
561
+ else:
562
+ raise ValueError(f"Invalid modality type: {mod_type}")
563
+
564
+ return masked_mod_dict
565
+
566
+
567
+ class TransferMasking(object):
568
+ def __init__(self,
569
+ modality_info: Dict,
570
+ text_tokenizer: Optional[Tokenizer],
571
+ input_modalities: List[str],
572
+ target_modalities: List[str]):
573
+ """Performs masking for transfer on a dict of modalities (both image based and sequence based modalities),
574
+ by specifying which modalities are inputs and which are targets.
575
+
576
+ Args:
577
+ modality_info: Dict with the modalities and their corresponding information
578
+ text_tokenizer: Tokenizer to use for text modalities
579
+ input_modalities: List of modalities to use as input
580
+ target_modalities: List of modalities to use as target
581
+ """
582
+ self.modality_info = modality_info
583
+ self.num_modalities = len(modality_info)
584
+ self.min_tokens = torch.tensor([mod['min_tokens'] for mod in modality_info.values()])
585
+ self.max_tokens = torch.tensor([mod['max_tokens'] for mod in modality_info.values()])
586
+ self.mod_is_img = torch.tensor([mod['type'] == 'img' for mod in modality_info.values()])
587
+
588
+ self.input_modalities = set(input_modalities)
589
+ self.target_modalities = set(target_modalities)
590
+
591
+ # Tokenizer for text modalities
592
+ self.text_tokenizer = text_tokenizer
593
+ if self.text_tokenizer is not None:
594
+ self.keep_prob_decay_factor = 0.9
595
+ self.sentinel_to_id = get_sentinel_to_id_mapping(text_tokenizer)
596
+ self.sentinel_ids = set(self.sentinel_to_id.values())
597
+ self.pad_id = text_tokenizer.token_to_id("[PAD]")
598
+ self.eos_id = text_tokenizer.token_to_id("[EOS]")
599
+
600
+
601
+ def input_image(self, tensor: torch.Tensor, num_tokens: int):
602
+ """Applies masking for an image given as input
603
+
604
+ Args:
605
+ tensor: Image tensor
606
+ num_tokens: Number of tokens in the tensor
607
+
608
+ Returns:
609
+ Dictionary containing the masked image tensor, the input mask, the target mask, and the decoder attention mask
610
+ """
611
+
612
+ # Input mask
613
+ input_mask = torch.zeros(num_tokens, dtype=torch.bool)
614
+ # Target mask
615
+ target_mask = torch.ones(num_tokens, dtype=torch.bool)
616
+ # Decoder attention mask
617
+ decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)
618
+
619
+ return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
620
+
621
+ def target_image(self, tensor: torch.Tensor, num_tokens: int):
622
+ """Applies masking for an image given as target
623
+
624
+ Args:
625
+ tensor: Image tensor
626
+ num_tokens: Number of tokens in the tensor
627
+
628
+ Returns:
629
+ Dictionary containing the masked image tensor, the input mask, the target mask, and the decoder attention mask
630
+ """
631
+
632
+ # Input mask
633
+ input_mask = torch.ones(num_tokens, dtype=torch.bool)
634
+ # Target mask
635
+ target_mask = torch.zeros(num_tokens, dtype=torch.bool)
636
+ # Decoder attention mask
637
+ decoder_attention_mask = torch.zeros(num_tokens, dtype=torch.int)
638
+ decoder_attention_mask[0] = num_tokens
639
+
640
+ return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
641
+
642
+
643
+ def input_sequence(self, sequence_str: str, max_tokens: int):
644
+ """Applies masking for a sequence given as input
645
+
646
+ Args:
647
+ sequence_str: Sequence string
648
+ max_tokens: Maximum number of tokens in the sequence
649
+
650
+ Returns:
651
+ Dictionary containing the masked sequence string, the input mask, the target mask, and the decoder attention mask
652
+ """
653
+ # Tokenize the text and get the ids
654
+ seq_ids = self.text_tokenizer.encode(sequence_str).ids
655
+ # Add EOS to all sequences
656
+ seq_ids.append(self.eos_id)
657
+ # Truncate sequence
658
+ seq_ids = seq_ids[:max_tokens]
659
+
660
+ keep_prob = 1.
661
+ input_seq_ids, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
662
+
663
+ # Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
664
+ max_length = (max_tokens + 1) * 2
665
+ tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
666
+ input_mask = torch.ones(max_length, dtype=torch.bool)
667
+ target_mask = torch.ones(max_length, dtype=torch.bool)
668
+ decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
669
+
670
+ # Set input and input mask
671
+ tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
672
+ input_mask[:len(input_seq_ids)] = 0
673
+
674
+ tensor[max_tokens:max_tokens + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
675
+ target_mask[max_tokens:max_tokens + len(target_seq_ids)] = 0
676
+ decoder_attention_mask[max_tokens:max_tokens + len(target_seq_ids)] = 1
677
+
678
+
679
+ return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask, "decoder_attention_mask": decoder_attention_mask}
680
+
681
+
682
+ def target_sequence(self, sequence_str: str, max_tokens: int):
683
+ """Applies masking for a sequence given as target
684
+
685
+ Args:
686
+ sequence_str: Sequence string
687
+ max_tokens: Maximum number of tokens in the sequence
688
+
689
+ Returns:
690
+ Dictionary containing the masked sequence string, the input mask, the target mask, and the decoder attention mask
691
+ """
692
+ # Tokenize the text and get the ids
693
+ seq_ids = self.text_tokenizer.encode(sequence_str).ids
694
+ # Add EOS to all sequences
695
+ seq_ids.append(self.eos_id)
696
+ # Truncate sequence
697
+ seq_ids = seq_ids[:max_tokens]
698
+
699
+ keep_prob = 0.
700
+ input_seq_ids = []
701
+ _, target_seq_ids = simple_span_masking(seq_ids, self.sentinel_to_id, keep_prob)
702
+
703
+ # Span masking can add up to (max_tokens + 1) * 2 tokens for input + target
704
+ max_length = (max_tokens + 1) * 2
705
+ tensor = torch.ones(max_length, dtype=torch.int) * self.pad_id
706
+ input_mask = torch.ones(max_length, dtype=torch.bool)
707
+ target_mask = torch.ones(max_length, dtype=torch.bool)
708
+ decoder_attention_mask = torch.zeros(max_length, dtype=torch.int)
709
+
710
+ # Set input and input mask
711
+ tensor[:len(input_seq_ids)] = torch.tensor(input_seq_ids, dtype=torch.int)
712
+ input_mask[:len(input_seq_ids)] = 0
713
+
714
+ tensor[max_tokens:max_tokens + len(target_seq_ids)] = torch.tensor(target_seq_ids, dtype=torch.int)
715
+ target_mask[max_tokens:max_tokens + len(target_seq_ids)] = 0
716
+ decoder_attention_mask[max_tokens:max_tokens + len(target_seq_ids)] = 1
717
+
718
+ return {"tensor": tensor, "input_mask": input_mask, "target_mask": target_mask,
719
+ "decoder_attention_mask": decoder_attention_mask}
720
+
721
+ def __call__(self, mod_dict):
722
+ """Applies input and target masking to a dictionary of modalities
723
+
724
+ Args:
725
+ mod_dict: Dictionary of modalities
726
+
727
+ Returns:
728
+ Dictionary containing the masked modalities
729
+ """
730
+ masked_mod_dict = {}
731
+ for mod_name, mod_info in self.modality_info.items():
732
+ mod_type = mod_info['type']
733
+ if mod_type == 'img' and mod_name in self.input_modalities:
734
+ masked_mod_dict[mod_name] = self.input_image(mod_dict[mod_name], mod_info['max_tokens'])
735
+ elif mod_type == 'img' and mod_name in self.target_modalities:
736
+ masked_mod_dict[mod_name] = self.target_image(mod_dict[mod_name], mod_info['max_tokens'])
737
+ elif mod_type == 'seq' and mod_name in self.input_modalities:
738
+ masked_mod_dict[mod_name] = self.input_sequence(mod_dict[mod_name], mod_info['max_tokens'])
739
+ elif mod_type == 'seq' and mod_name in self.target_modalities:
740
+ masked_mod_dict[mod_name] = self.target_sequence(mod_dict[mod_name], mod_info['max_tokens'])
741
+ else:
742
+ raise ValueError(f"Invalid modality type: {mod_type} or modality name not in input or target modalities: {mod_name}")
743
+
744
+ if 'mask_valid' in mod_dict:
745
+ masked_mod_dict['mask_valid'] = mod_dict['mask_valid']
746
+
747
+ return masked_mod_dict
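To make the span-masking convention concrete, a small sketch of `simple_span_masking` on a toy token sequence (the sentinel-id mapping here is made up for illustration; in the real pipeline it comes from `get_sentinel_to_id_mapping(text_tokenizer)`):

    from fourm.data.masking import simple_span_masking

    sequence = [1, 2, 3, 4, 5, 6, 7, 8]                    # toy content token ids
    sentinel_to_id = {i: 1000 + i for i in range(1, 10)}   # toy sentinel ids

    input_seq, target_seq = simple_span_masking(sequence, sentinel_to_id, keep_prob=0.5)
    # Each contiguous masked span is replaced by one sentinel in the input, and the
    # target lists each sentinel followed by the tokens it hides, terminated by the
    # next unused sentinel. One possible outcome, if tokens 2-3 and 6-7 get masked:
    #   input_seq  == [1, 1001, 4, 5, 1002, 8]
    #   target_seq == [1001, 2, 3, 1002, 6, 7, 1003]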
fourm/data/modality_info.py ADDED
@@ -0,0 +1,427 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from functools import partial
15
+
16
+ import fourm.utils.data_constants as data_constants
17
+ from fourm.data.modality_transforms import (CaptionTransform, DepthTransform,
18
+ DetectionTransform, MaskTransform,
19
+ NormalTransform, RGBTransform,
20
+ SemsegTransform, TokTransform,
21
+ CaptionEmbTransform, MetadataTransform,
22
+ HumanPoseTransform, ColorPaletteTransform,
23
+ SAMInstanceTokTransform, SAMInstanceTransform)
24
+ from fourm.models.decoder_embeddings import (ImageTokenDecoderEmbedding,
25
+ SequenceDecoderEmbedding)
26
+ from fourm.models.encoder_embeddings import (ImageEncoderEmbedding,
27
+ ImageTokenEncoderEmbedding,
28
+ SequenceEncoderEmbedding,
29
+ SequenceEmbEncoderEmbedding)
30
+ from fourm.utils import generate_uint15_hash
31
+
32
+ MODALITY_INFO = {
33
+ # 4M-7 modalities
34
+ 'rgb@224': {
35
+ 'input_size': 224,
36
+ 'patch_size': 16,
37
+ 'encoder_embedding': partial(ImageEncoderEmbedding, num_channels=3),
38
+ 'decoder_embedding': None,
39
+ 'min_tokens': 0,
40
+ 'max_tokens': None, # Will be set to 196
41
+ 'type': 'img',
42
+ 'num_channels': 3,
43
+ 'id': generate_uint15_hash('rgb@224'),
44
+ 'path': 'rgb',
45
+ },
46
+ 'rgb': { # used for tokenizer training
47
+ 'type': 'img',
48
+ 'num_channels': 3,
49
+ 'id': generate_uint15_hash('rgb'),
50
+ 'path': 'rgb',
51
+ },
52
+ 'caption': {
53
+ 'vocab_size': 30_000,
54
+ 'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0),
55
+ 'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0),
56
+ 'min_tokens': 0,
57
+ 'max_tokens': 256,
58
+ 'type': 'seq',
59
+ 'id': generate_uint15_hash('caption'),
60
+ },
61
+ 'det': {
62
+ 'vocab_size': 30_000,
63
+ 'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0),
64
+ 'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0),
65
+ 'min_tokens': 0,
66
+ 'max_tokens': 256,
67
+ 'type': 'seq',
68
+ 'id': generate_uint15_hash('det'),
69
+ },
70
+ 'tok_rgb@224': {
71
+ 'input_size': 224,
72
+ 'patch_size': 16,
73
+ 'vocab_size': 16384,
74
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=16384),
75
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=16384),
76
+ 'min_tokens': 0,
77
+ 'max_tokens': None, # Will be set to 196
78
+ 'type': 'img',
79
+ 'id': generate_uint15_hash('tok_rgb@224'),
80
+ 'pretokenized': True,
81
+ },
82
+ 'tok_depth@224': {
83
+ 'input_size': 224,
84
+ 'patch_size': 16,
85
+ 'vocab_size': 8192,
86
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
87
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
88
+ 'min_tokens': 0,
89
+ 'max_tokens': None, # Will be set to 196
90
+ 'type': 'img',
91
+ 'id': generate_uint15_hash('tok_depth@224'),
92
+ 'pretokenized': True,
93
+ },
94
+ 'depth': { # used for tokenizer training
95
+ 'type': 'img',
96
+ 'num_channels': 1,
97
+ 'id': generate_uint15_hash('depth'),
98
+ },
99
+ 'tok_normal@224': {
100
+ 'input_size': 224,
101
+ 'patch_size': 16,
102
+ 'vocab_size': 8192,
103
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
104
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
105
+ 'min_tokens': 0,
106
+ 'max_tokens': None, # Will be set to 196
107
+ 'type': 'img',
108
+ 'id': generate_uint15_hash('tok_normal@224'),
109
+ 'pretokenized': True,
110
+ },
111
+ 'normal': { # used for tokenizer training
112
+ 'type': 'img',
113
+ 'num_channels': 3,
114
+ 'id': generate_uint15_hash('normal'),
115
+ },
116
+ 'tok_semseg@224': {
117
+ 'input_size': 224,
118
+ 'patch_size': 16,
119
+ 'vocab_size': 4096,
120
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=4096),
121
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=4096),
122
+ 'min_tokens': 0,
123
+ 'max_tokens': None, # Will be set to 196
124
+ 'type': 'img',
125
+ 'id': generate_uint15_hash('tok_semseg@224'),
126
+ 'pretokenized': True,
127
+ },
128
+ 'semseg_coco': { # used for tokenizer training
129
+ 'type': 'img',
130
+ 'num_channels': 64,
131
+ 'num_labels': data_constants.COCO_SEMSEG_NUM_CLASSES,
132
+ 'id': generate_uint15_hash('semseg_coco'),
133
+ },
134
+ 'tok_clip@224': {
135
+ 'input_size': 224,
136
+ 'patch_size': 16,
137
+ 'vocab_size': 8192,
138
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
139
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
140
+ 'min_tokens': 0,
141
+ 'max_tokens': None, # Will be set to 196
142
+ 'type': 'img',
143
+ 'id': generate_uint15_hash('tok_clip@224'),
144
+ 'pretokenized': True,
145
+ },
146
+ 'CLIP-B16': { # used for tokenizer training
147
+ 'type': 'feature_map',
148
+ 'num_channels': 512,
149
+ 'id': generate_uint15_hash('CLIP-B16'),
150
+ },
151
+
152
+ # 4M-21 modalities
153
+ 't5_caption': {
154
+ 'encoder_embedding': partial(SequenceEmbEncoderEmbedding, max_length=77, padding_idx=0),
155
+ 'decoder_embedding': None,
156
+ 'min_tokens': 0,
157
+ 'max_tokens': 77,
158
+ 'type': 'seq_emb',
159
+ 'id': generate_uint15_hash('t5_caption'),
160
+ },
161
+ 'metadata': {
162
+ 'vocab_size': 30_000,
163
+ 'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=40, padding_idx=0, sincos_pos_emb=True),
164
+ 'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=40, padding_idx=0, sincos_pos_emb=True),
165
+ 'min_tokens': 0,
166
+ 'max_tokens': 40, # At most 2x19=38 for 19 metadata types, +1 for EOS, +1 for sentinel
167
+ 'type': 'seq',
168
+ 'id': generate_uint15_hash('metadata'),
169
+ 'shared_vocab': ['caption'],
170
+ 'path': 'metadata',
171
+ },
172
+ 'human_poses': {
173
+ 'vocab_size': 30_000,
174
+ 'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=263, padding_idx=0, sincos_pos_emb=True),
175
+ 'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=263, padding_idx=0, sincos_pos_emb=True),
176
+ 'min_tokens': 0,
177
+ 'max_tokens': 275, # 7*39=273 pose tokens, +1 for EOS, +1 for sentinel
178
+ 'type': 'seq',
179
+ 'num_channels': 207, # for tokenization training, only the pose part is needed
180
+ 'id': generate_uint15_hash('human_poses'),
181
+ 'shared_vocab': ['caption'],
182
+ },
183
+ 'color_palette': {
184
+ 'vocab_size': 30_000,
185
+ 'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=23, padding_idx=0, sincos_pos_emb=True),
186
+ 'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=23, padding_idx=0, sincos_pos_emb=True),
187
+ 'min_tokens': 0,
188
+ 'max_tokens': 23, #7x3=21 for 7 colors, +1 for EOS, +1 for sentinel
189
+ 'type': 'seq',
190
+ 'id': generate_uint15_hash('color_palette'),
191
+ 'shared_vocab': ['caption'],
192
+ 'path': 'color_palette',
193
+ },
194
+ 'sam_mask': {
195
+ 'encoder_embedding': None,
196
+ 'decoder_embedding': None,
197
+ 'min_tokens': 0,
198
+ 'max_tokens': 64,
199
+ 'type': 'img',
200
+ 'num_channels': 1,
201
+ 'id': generate_uint15_hash('sam_mask'),
202
+ },
203
+ 'sam_instance': {
204
+ 'vocab_size': 30_000,
205
+ 'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=290, padding_idx=0, sincos_pos_emb=True),
206
+ 'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=290, padding_idx=0, sincos_pos_emb=True),
207
+ 'min_tokens': 0,
208
+ 'max_tokens': 290,
209
+ 'type': 'seq',
210
+ 'id': generate_uint15_hash('sam_instance'),
211
+ 'shared_vocab': ['caption'],
212
+ 'pretokenized': True,
213
+ },
214
+ 'tok_canny_edge@224': {
215
+ 'input_size': 224,
216
+ 'patch_size': 16,
217
+ 'vocab_size': 8192,
218
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
219
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
220
+ 'min_tokens': 0,
221
+ 'max_tokens': None, # Will be set to 196
222
+ 'type': 'img',
223
+ 'id': generate_uint15_hash('tok_canny_edge@224'),
224
+ 'pretokenized': True,
225
+ },
226
+ 'canny_edge': { # used for tokenizer training
227
+ 'type': 'img',
228
+ 'num_channels': 1,
229
+ 'id': generate_uint15_hash('canny_edge'),
230
+ },
231
+ 'tok_sam_edge@224': {
232
+ 'input_size': 224,
233
+ 'patch_size': 16,
234
+ 'vocab_size': 8192,
235
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
236
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
237
+ 'min_tokens': 0,
238
+ 'max_tokens': None, # Will be set to 196
239
+ 'type': 'img',
240
+ 'id': generate_uint15_hash('tok_sam_edge@224'),
241
+ 'pretokenized': True,
242
+ },
243
+ 'tok_dinov2@224': {
244
+ 'input_size': 224,
245
+ 'patch_size': 14,
246
+ 'vocab_size': 8192,
247
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
248
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
249
+ 'min_tokens': 0,
250
+ 'max_tokens': None, # Will be set to 256
251
+ 'type': 'img',
252
+ 'id': generate_uint15_hash('tok_dinov2@224'),
253
+ 'pretokenized': True,
254
+ },
255
+ 'DINOv2-B14': { # used for tokenizer training
256
+ 'type': 'feature_map',
257
+ 'num_channels': 768,
258
+ 'id': generate_uint15_hash('DINOv2-B14'),
259
+ },
260
+ 'tok_imagebind@224': {
261
+ 'input_size': 224,
262
+ 'patch_size': 14,
263
+ 'vocab_size': 8192,
264
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
265
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
266
+ 'min_tokens': 0,
267
+ 'max_tokens': None, # Will be set to 256
268
+ 'type': 'img',
269
+ 'id': generate_uint15_hash('tok_imagebind@224'),
270
+ 'pretokenized': True,
271
+ },
272
+ 'ImageBind-H14': { # used for tokenizer training
273
+ 'type': 'feature_map',
274
+ 'num_channels': 1280,
275
+ 'id': generate_uint15_hash('ImageBind-H14'),
276
+ },
277
+ 'tok_dinov2_global': {
278
+ 'vocab_size': 8192,
279
+ 'patch_size': 56,
280
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192, sincos_pos_emb=False),
281
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192, sincos_pos_emb=False),
282
+ 'min_tokens': 0,
283
+ 'max_tokens': 16,
284
+ 'type': 'img',
285
+ 'id': generate_uint15_hash('tok_dinov2_global'),
286
+ 'pretokenized': True,
287
+ },
288
+ 'DINOv2-B14-global': { # used for tokenizer training
289
+ 'type': 'feature_map',
290
+ 'num_channels': 768,
291
+ 'id': generate_uint15_hash('DINOv2-B14-global'),
292
+ },
293
+ 'tok_imagebind_global': {
294
+ 'vocab_size': 8192,
295
+ 'patch_size': 56,
296
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192, sincos_pos_emb=False),
297
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192, sincos_pos_emb=False),
298
+ 'min_tokens': 0,
299
+ 'max_tokens': 16,
300
+ 'type': 'img',
301
+ 'id': generate_uint15_hash('tok_imagebind_global'),
302
+ 'pretokenized': True,
303
+ },
304
+ 'ImageBind-H14-global': { # used for tokenizer training
305
+ 'type': 'feature_map',
306
+ 'num_channels': 1280,
307
+ 'id': generate_uint15_hash('ImageBind-H14-global'),
308
+ },
309
+
310
+ ### 224->448 super resolution modalities
311
+ 'rgb@448': {
312
+ 'input_size': 448,
313
+ 'patch_size': 16,
314
+ 'encoder_embedding': partial(ImageEncoderEmbedding, num_channels=3),
315
+ 'decoder_embedding': None,
316
+ 'min_tokens': 0,
317
+ 'max_tokens': None, # Will be set to 784
318
+ 'type': 'img',
319
+ 'num_channels': 3,
320
+ 'id': generate_uint15_hash('rgb@448'),
321
+ 'path': 'rgb',
322
+ },
323
+ 'tok_rgb@448': {
324
+ 'input_size': 448,
325
+ 'patch_size': 16,
326
+ 'vocab_size': 16384,
327
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=16384),
328
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=16384),
329
+ 'min_tokens': 0,
330
+ 'max_tokens': None, # Will be set to 784
331
+ 'type': 'img',
332
+ 'id': generate_uint15_hash('tok_rgb@448'),
333
+ 'pretokenized': True,
334
+ },
335
+ 'tok_depth@448': {
336
+ 'input_size': 448,
337
+ 'patch_size': 16,
338
+ 'vocab_size': 8192,
339
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
340
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
341
+ 'min_tokens': 0,
342
+ 'max_tokens': None, # Will be set to 784
343
+ 'type': 'img',
344
+ 'id': generate_uint15_hash('tok_depth@448'),
345
+ 'pretokenized': True,
346
+ },
347
+ 'tok_normal@448': {
348
+ 'input_size': 448,
349
+ 'patch_size': 16,
350
+ 'vocab_size': 8192,
351
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
352
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
353
+ 'min_tokens': 0,
354
+ 'max_tokens': None, # Will be set to 784
355
+ 'type': 'img',
356
+ 'id': generate_uint15_hash('tok_normal@448'),
357
+ 'pretokenized': True,
358
+ },
359
+ 'tok_semseg@448': {
360
+ 'input_size': 448,
361
+ 'patch_size': 16,
362
+ 'vocab_size': 4096,
363
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=4096),
364
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=4096),
365
+ 'min_tokens': 0,
366
+ 'max_tokens': None, # Will be set to 784
367
+ 'type': 'img',
368
+ 'id': generate_uint15_hash('tok_semseg@448'),
369
+ 'pretokenized': True,
370
+ },
371
+ 'tok_clip@448': {
372
+ 'input_size': 448,
373
+ 'patch_size': 16,
374
+ 'vocab_size': 8192,
375
+ 'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
376
+ 'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
377
+ 'min_tokens': 0,
378
+ 'max_tokens': None, # Will be set to 784
379
+ 'type': 'img',
380
+ 'id': generate_uint15_hash('tok_clip@448'),
381
+ 'pretokenized': True,
382
+ },
383
+ }
384
+
385
+ # Note: @res suffix is ignored for modality transforms
386
+ MODALITY_TRANSFORMS = {
387
+ # 4M-7 modalities
388
+ 'rgb': RGBTransform(imagenet_default_mean_and_std=True),
389
+ 'caption': CaptionTransform(aligned_captions=True),
390
+ 'det': DetectionTransform(det_threshold=0.6, det_max_instances=None, bbox_order='dist_to_orig', coord_bins=1000, min_visibility=0.0),
391
+ 'tok_rgb': TokTransform(),
392
+ 'tok_depth': TokTransform(),
393
+ 'tok_normal': TokTransform(),
394
+ 'tok_semseg': TokTransform(),
395
+ 'tok_clip': TokTransform(),
396
+ # 4M-21 modalities
397
+ 't5_caption': CaptionEmbTransform(),
398
+ 'metadata': MetadataTransform(special_vmin=0, special_vmax=999, shuffle=True, random_trunc=False, return_chunks=True),
399
+ 'human_poses': HumanPoseTransform(coord_bins=1000),
400
+ 'color_palette': ColorPaletteTransform(coord_bins=1000),
401
+ 'sam_instance': SAMInstanceTokTransform(image_size=224, points_per_side=7, point_order='random'),
402
+ 'tok_canny_edge': TokTransform(),
403
+ 'tok_sam_edge': TokTransform(),
404
+ 'tok_dinov2': TokTransform(),
405
+ 'tok_imagebind': TokTransform(),
406
+ 'tok_dinov2_global': TokTransform(),
407
+ 'tok_imagebind_global': TokTransform(),
408
+ # Other
409
+ 'mask_valid': MaskTransform(mask_pool_size=1),
410
+ }
411
+
412
+ MODALITY_TRANSFORMS_DIVAE = {
413
+ 'rgb': RGBTransform(imagenet_default_mean_and_std=False),
414
+ 'depth': DepthTransform(standardize_depth=True),
415
+ 'normal': NormalTransform(standardize_surface_normals=False),
416
+ 'mask_valid': MaskTransform(mask_pool_size=1),
417
+ 'semseg_coco': SemsegTransform(shift_idx_by_one=True),
418
+ 'canny_edge': RGBTransform(imagenet_default_mean_and_std=False),
419
+ 'human_poses': HumanPoseTransform(coord_bins=1000, only_pose=True),
420
+ 'sam_mask': SAMInstanceTransform(mask_size=64, max_instance_n=1),
421
+ }
422
+
423
+ MODALITY_TRANSFORMS_VQCONTROLNET = {
424
+ 'rgb': RGBTransform(imagenet_default_mean_and_std=False),
425
+ 'mask_valid': MaskTransform(mask_pool_size=1),
426
+ 'caption': CaptionTransform(aligned_captions=True),
427
+ }
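Usage sketch (commentary, not part of the diff): the registries above are meant to be used together. Assuming the dictionary closed further up is the module-level MODALITY_INFO (its opening lies before this excerpt) and that the package is importable as fourm, a lookup might read:

    from fourm.data.modality_info import MODALITY_INFO, MODALITY_TRANSFORMS

    mod = 'tok_clip@224'
    info = MODALITY_INFO[mod]                           # per-modality config: vocab_size=8192, type='img', ...
    transform = MODALITY_TRANSFORMS[mod.split('@')[0]]  # transforms are keyed without the @resolution suffix -> TokTransform()
    print(info['vocab_size'], info['max_tokens'])       # 8192, None (resolved to 196 at model build time)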
fourm/data/modality_transforms.py ADDED
@@ -0,0 +1,1387 @@
 
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import gzip
15
+ import json
16
+ import random
17
+ from pathlib import Path
18
+ from typing import Optional, Tuple, List, Dict
19
+ from abc import ABC, abstractmethod
20
+
21
+ from PIL import Image
22
+ import cv2
23
+
24
+ import albumentations as A
25
+ import numpy as np
26
+ import torch
27
+ import torchvision.transforms.functional as TF
28
+ import torchvision.transforms as T
29
+ from einops import rearrange, repeat, reduce
30
+
31
+ from fourm.utils import to_2tuple
32
+ from fourm.utils.data_constants import (IMAGENET_DEFAULT_MEAN,
33
+ IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN,
34
+ IMAGENET_SURFACE_NORMAL_STD, IMAGENET_SURFACE_NORMAL_MEAN,
35
+ IMAGENET_INCEPTION_STD, SEG_IGNORE_INDEX, PAD_MASK_VALUE)
36
+
37
+
38
+ # The @-symbol is used to specify the resolution of a modality. Syntax: modality@resolution
39
+ def get_transform_key(mod_name):
40
+ return mod_name.split('@')[0]
41
+
42
+ def get_transform_resolution(mod_name, default_resolution, to_tuple=True):
43
+ res = int(mod_name.split('@')[1]) if '@' in mod_name else default_resolution
44
+ return to_2tuple(res) if to_tuple else res
45
+
46
+ def get_transform(mod_name, transforms_dict):
47
+ return transforms_dict.get(get_transform_key(mod_name), IdentityTransform())
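Illustration (commentary, not file content) of how the helpers above handle the @resolution suffix; the values are examples:

    get_transform_key('tok_rgb@448')                          # -> 'tok_rgb'
    get_transform_resolution('tok_rgb@448', 224)              # -> (448, 448)
    get_transform_resolution('caption', 224, to_tuple=False)  # -> 224
    get_transform('tok_rgb@448', MODALITY_TRANSFORMS)         # -> the TokTransform() registered in modality_info.py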
48
+
49
+ def get_pil_resample_mode(resample_mode: str):
50
+ """
51
+ Returns the PIL resampling mode for the given resample mode string.
52
+
53
+ Args:
54
+ resample_mode: Resampling mode string
55
+ """
56
+ if resample_mode is None:
57
+ return None
58
+ elif resample_mode == "bilinear":
59
+ return Image.Resampling.BILINEAR if hasattr(Image, 'Resampling') else Image.BILINEAR
60
+ elif resample_mode == "bicubic":
61
+ return Image.Resampling.BICUBIC if hasattr(Image, 'Resampling') else Image.BICUBIC
62
+ elif resample_mode == "nearest":
63
+ return Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST
64
+ else:
65
+ raise ValueError(f"Resample mode {resample_mode} is not supported.")
66
+
67
+ class UnifiedDataTransform(object):
68
+ def __init__(self, transforms_dict, image_augmenter, resample_mode: str = None, add_sizes: bool = False, **kwargs):
69
+ """Unified data augmentation for FourM
70
+
71
+ Args:
72
+ transforms_dict (dict): Dict of transforms for each modality
73
+ image_augmenter (AbstractImageAugmenter): Image augmenter
74
+ resample_mode (str, optional): Resampling mode for PIL images (default: None -> uses default resampling mode for data type)
75
+ One out of ["bilinear", "bicubic", "nearest", None].
76
+ add_sizes (bool, optional): Whether to add crop coordinates and original size to the output dict
77
+ """
78
+
79
+ self.transforms_dict = transforms_dict
80
+ self.image_augmenter = image_augmenter
81
+ self.resample_mode = resample_mode
82
+ self.add_sizes = add_sizes
83
+
84
+ def unified_image_augment(self, mod_dict, crop_settings):
85
+ """Apply the image augmenter to all modalities where it is applicable
86
+
87
+ Args:
88
+ mod_dict (dict): Dict of modalities
89
+ crop_settings (dict): Crop settings
90
+
91
+ Returns:
92
+ dict: Transformed dict of modalities
93
+ """
94
+
95
+ crop_coords, flip, orig_size, target_size, rand_aug_idx = self.image_augmenter(mod_dict, crop_settings)
96
+
97
+ mod_dict = {
98
+ k: self.transforms_dict[get_transform_key(k)].image_augment(
99
+ v, crop_coords=crop_coords, flip=flip, orig_size=orig_size,
100
+ target_size=get_transform_resolution(k, target_size), rand_aug_idx=rand_aug_idx,
101
+ resample_mode=self.resample_mode
102
+ )
103
+ for k, v in mod_dict.items()
104
+ }
105
+
106
+ if self.add_sizes:
107
+ mod_dict["crop_coords"] = torch.tensor(crop_coords)
108
+ mod_dict["orig_size"] = torch.tensor(orig_size)
109
+
110
+ return mod_dict
111
+
112
+ def __call__(self, mod_dict):
113
+ """Apply the augmentation to a dict of modalities (both image based and sequence based modalities)
114
+
115
+ Args:
116
+ mod_dict (dict): Dict of modalities
117
+
118
+ Returns:
119
+ dict: Transformed dict of modalities
120
+ """
121
+ crop_settings = mod_dict.pop("crop_settings", None)
122
+
123
+ mod_dict = {k: get_transform(k, self.transforms_dict).preprocess(v) for k, v in mod_dict.items()}
124
+
125
+ mod_dict = self.unified_image_augment(mod_dict, crop_settings)
126
+
127
+ mod_dict = {k: get_transform(k, self.transforms_dict).postprocess(v) for k, v in mod_dict.items()}
128
+
129
+ return mod_dict
130
+
131
+ def __repr__(self):
132
+ repr = "(UnifiedDataAugmentation,\n"
133
+ repr += ")"
134
+ return repr
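Minimal end-to-end sketch of UnifiedDataTransform (commentary only). FixedCropAugmenter is a toy stand-in for the augmenters in fourm/data/image_augmenter.py, and RGBTransform / CaptionTransform are defined further down in this file:

    from PIL import Image

    class FixedCropAugmenter:
        # Toy augmenter: full 224x224 crop, no flip, no pre-tokenized augmentation index.
        def __call__(self, mod_dict, crop_settings):
            return (0, 0, 224, 224), False, (224, 224), (224, 224), None

    unified = UnifiedDataTransform(
        transforms_dict={'rgb': RGBTransform(), 'caption': CaptionTransform()},
        image_augmenter=FixedCropAugmenter(),
        add_sizes=True,
    )
    out = unified({'rgb': Image.new('RGB', (224, 224)), 'caption': ['a photo of a cat']})
    # out['rgb'] is a normalized 3x224x224 tensor, out['caption'] is the first caption string,
    # and out additionally contains 'crop_coords' and 'orig_size' tensors.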
135
+
136
+
137
+ class AbstractTransform(ABC):
138
+
139
+ @abstractmethod
140
+ def load(self, sample):
141
+ pass
142
+
143
+ @abstractmethod
144
+ def preprocess(self, sample):
145
+ pass
146
+
147
+ @abstractmethod
148
+ def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
149
+ rand_aug_idx: Optional[int], resample_mode: str = None):
150
+ pass
151
+
152
+ @abstractmethod
153
+ def postprocess(self, v):
154
+ pass
155
+
156
+
157
+ class ImageTransform(AbstractTransform):
158
+
159
+ @staticmethod
160
+ def pil_loader(path: str) -> Image.Image:
161
+ # Note: opening via an explicit file handle (commented out below) avoids a ResourceWarning (https://github.com/python-pillow/Pillow/issues/835); the path is opened directly here instead.
162
+ # with open(path, 'rb') as f:
163
+ # img = Image.open(f)
164
+ img = Image.open(path)
165
+ return img
166
+
167
+
168
+ @staticmethod
169
+ def image_hflip(img: Image, flip: bool):
170
+ """Crop and resize an image
171
+
172
+ :param img: Image to flip
173
+ :param flip: Whether to flip the image
174
+ :return: Flipped image (if flip = True)
175
+ """
176
+ if flip:
177
+ img = TF.hflip(img)
178
+ return img
179
+
180
+ @staticmethod
181
+ def image_crop_and_resize(img: Image, crop_coords: Tuple, target_size: Tuple, resample_mode: str = None):
182
+ """Crop and resize an image
183
+
184
+ :param img: Image to crop and resize
185
+ :param crop_coords: Coordinates of the crop (top, left, h, w)
186
+ :param target_size: Coordinates of the resize (height, width)
187
+ :return: Cropped and resized image
188
+ """
189
+
190
+ top, left, h, w = crop_coords
191
+ resize_height, resize_width = target_size
192
+ img = TF.crop(img, top, left, h, w)
193
+ resample_mode = get_pil_resample_mode(resample_mode)
194
+ img = img.resize((resize_height, resize_width), resample=resample_mode)
195
+ return img
196
+
197
+
198
+ class RGBTransform(ImageTransform):
199
+
200
+ def __init__(self, imagenet_default_mean_and_std=True, color_jitter=False, color_jitter_strength=0.5):
201
+ self.rgb_mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
202
+ self.rgb_std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
203
+ self.color_jitter = color_jitter
204
+ self.color_jitter_transform = self.random_color_jitter(color_jitter_strength)
205
+
206
+ def random_color_jitter(self, strength=0.5):
207
+ # Color Jitter from Pix2Seq and SimCLR
208
+ # Source: https://github.com/google-research/pix2seq/blob/main/data/data_utils.py#L114
209
+ t = T.Compose([
210
+ T.RandomApply([T.ColorJitter(brightness=0.8 * strength, contrast=0.8 * strength, saturation=0.8 * strength, hue=0.2 * strength)], p=0.8),
211
+ T.RandomApply([T.Grayscale(num_output_channels=3)], p=0.2),
212
+ ])
213
+
214
+ return t
215
+
216
+ def rgb_to_tensor(self, img):
217
+ img = TF.to_tensor(img)
218
+ img = TF.normalize(img, mean=self.rgb_mean, std=self.rgb_std)
219
+ return img
220
+
221
+ def load(self, path):
222
+ # TODO: Instead of converting to RGB here, do it either in the preprocess or the postprocess step. Makes it compatible with wds dataloading.
223
+ sample = self.pil_loader(path)
224
+ return sample
225
+
226
+ def preprocess(self, sample):
227
+ sample = sample.convert('RGB')
228
+
229
+ if self.color_jitter:
230
+ sample = self.color_jitter_transform(sample)
231
+
232
+ return sample
233
+
234
+ def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
235
+ rand_aug_idx: Optional[int], resample_mode: str = None):
236
+ img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode=resample_mode)
237
+ img = self.image_hflip(img, flip)
238
+ return img
239
+
240
+ def postprocess(self, sample):
241
+ sample = self.rgb_to_tensor(sample)
242
+ return sample
243
+
244
+
245
+ class DepthTransform(ImageTransform):
246
+
247
+ def __init__(self, standardize_depth=True):
248
+ self.standardize_depth = standardize_depth
249
+
250
+ def depth_to_tensor(self, img):
251
+ img = torch.Tensor( img / (2 ** 16 - 1.0) )
252
+ img = img.unsqueeze(0) # 1 x H x W
253
+ if self.standardize_depth:
254
+ img = self.truncated_depth_standardization(img)
255
+ return img
256
+
257
+ @staticmethod
258
+ def truncated_depth_standardization(depth, thresh: float = 0.1):
259
+ """Truncated depth standardization
260
+
261
+ :param depth: Depth map
262
+ :param thresh: Threshold
263
+ :return: Robustly standardized depth map
264
+ """
265
+ # Flatten depth and remove bottom and top 10% of values
266
+ trunc_depth = torch.sort(depth.reshape(-1), dim=0)[0]
267
+ trunc_depth = trunc_depth[int(thresh * trunc_depth.shape[0]): int((1 - thresh) * trunc_depth.shape[0])]
268
+ return (depth - trunc_depth.mean()) / torch.sqrt(trunc_depth.var() + 1e-6)
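Quick check of the truncated standardization above (illustrative only):

    import torch
    depth = torch.rand(1, 64, 64) * 10.0                       # fake depth map
    z = DepthTransform.truncated_depth_standardization(depth)  # static method, usable directly
    # z = (depth - m) / s with m, s computed on the central 80% of depth values,
    # so extreme near/far readings do not dominate the normalization.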
269
+
270
+ def load(self, path):
271
+ sample = self.pil_loader(path)
272
+ return sample
273
+
274
+ def preprocess(self, sample):
275
+ return sample
276
+
277
+ def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
278
+ rand_aug_idx: Optional[int], resample_mode: str = None):
279
+ img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode=resample_mode)
280
+ img = self.image_hflip(img, flip)
281
+ return img
282
+
283
+ def postprocess(self, sample):
284
+ sample = np.array(sample)
285
+ sample = self.depth_to_tensor(sample)
286
+ return sample
287
+
288
+
289
+ class NormalTransform(ImageTransform):
290
+
291
+ def __init__(self, standardize_surface_normals=False):
292
+ self.normal_mean = (0.5, 0.5, 0.5) if not standardize_surface_normals else IMAGENET_SURFACE_NORMAL_MEAN
293
+ self.normal_std = (0.5, 0.5, 0.5) if not standardize_surface_normals else IMAGENET_SURFACE_NORMAL_STD
294
+
295
+ def normal_to_tensor(self, img):
296
+ img = TF.to_tensor(img)
297
+ img = TF.normalize(img, mean=self.normal_mean, std=self.normal_std)
298
+ return img
299
+
300
+ def load(self, path):
301
+ sample = self.pil_loader(path)
302
+ return sample
303
+
304
+ def preprocess(self, sample):
305
+ return sample
306
+
307
+ def image_hflip(self, img: Image, flip: bool):
308
+ if flip:
309
+ img = TF.hflip(img)
310
+ flipped_np = np.array(img)
311
+ flipped_np[:, :, 0] = 255 - flipped_np[:, :, 0]
312
+ img = Image.fromarray(flipped_np)
313
+
314
+ return img
315
+
316
+ def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
317
+ rand_aug_idx: Optional[int], resample_mode: str = None):
318
+ img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode=resample_mode)
319
+ img = self.image_hflip(img, flip)
320
+ return img
321
+
322
+ def postprocess(self, sample):
323
+ sample = self.normal_to_tensor(sample)
324
+ return sample
325
+
326
+
327
+ class SemsegTransform(ImageTransform):
328
+
329
+ def __init__(self, scale_factor=1.0, shift_idx_by_one=False, id_mapping: Optional[Dict] = None, select_channel=None):
330
+ self.scale_factor = scale_factor
331
+ self.shift_idx_by_one = shift_idx_by_one
332
+ self.id_mapping = id_mapping
333
+ self.select_channel = select_channel
334
+
335
+ def map_semseg_values(self, sample):
336
+ sample = np.asarray(sample)
337
+ mapping_fn = lambda x: self.id_mapping.get(x, x)
338
+ sample = np.vectorize(mapping_fn)(sample)
339
+ sample = Image.fromarray(sample, mode='P')
340
+ return sample
341
+
342
+ def semseg_to_tensor(self, img):
343
+ # Rescale to scale factor
344
+ if self.scale_factor != 1.0:
345
+ target_height, target_width = int(img.height * self.scale_factor), int(img.width * self.scale_factor)
346
+ img = img.resize((target_width, target_height))
347
+ # Using pil_to_tensor keeps it in uint8, to_tensor converts it to float (rescaled to [0, 1])
348
+ img = TF.pil_to_tensor(img).to(torch.long).squeeze(0)
349
+ # 255->0, 254->0, all else shifted up by one
350
+ return img
351
+
352
+ def load(self, path):
353
+ sample = self.pil_loader(path)
354
+ if self.select_channel is not None:
355
+ sample = sample.split()[self.select_channel]
356
+ return sample
357
+
358
+ def preprocess(self, sample):
359
+ sample = sample.convert('P')
360
+
361
+ if self.id_mapping is not None:
362
+ sample = self.map_semseg_values(sample)
363
+
364
+ if self.shift_idx_by_one:
365
+ sample = np.asarray(sample)
366
+ sample = sample + 1
367
+ sample = Image.fromarray(sample, mode='P')
368
+
369
+ return sample
370
+
371
+ def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
372
+ rand_aug_idx: Optional[int], resample_mode: str = None):
373
+ # Value for padding with TF.crop is always 0.
374
+ # Override resampling mode to 'nearest' for semseg
375
+ img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode='nearest')
376
+ img = self.image_hflip(img, flip)
377
+ return img
378
+
379
+ def postprocess(self, sample):
380
+ img = self.semseg_to_tensor(sample)
381
+ return img
382
+
383
+
384
+ class SAMInstanceTransform(AbstractTransform):
385
+
386
+ def __init__(self, mask_size=64, max_instance_n=20, bbox_area_threshold=0.0005):
387
+ self.mask_size = mask_size
388
+ self.max_instance_n = max_instance_n
389
+ self.bbox_area_threshold = bbox_area_threshold
390
+
391
+ def get_bbox(self, instance):
392
+ """ Gets bounding box of the given instance
393
+ """
394
+ min_h, max_h = instance[:,:,1].min(), instance[:,:,1].max()
395
+ min_w, max_w = instance[:,:,0].min(), instance[:,:,0].max()
396
+ return [min_h, min_w, max_h, max_w]
397
+
398
+ def extend_instance_points(self, instance, border_fn):
399
+ """ Given an instance and a border function `border_fn`, extends the instance points with crossing points between the instance and
400
+ the crop borders. The crossing points are obtained using border_fn.
401
+ """
402
+ p = instance[:,0]
403
+ p_next = np.roll(p, (-1), axis=(0))
404
+ final_points = []
405
+ for x, xn in zip(p, p_next):
406
+ final_points.append(x)
407
+ for r in border_fn(x, xn):
408
+ final_points.append(r.astype(np.int32))
409
+ p = np.stack(final_points)
410
+ return p[:,None]
411
+
412
+ def remove_redundant_lines(self, orig_instance, instance):
413
+ """ Removes the redundant lines added during cropping.
414
+ """
415
+ final_points = []
416
+ for p in instance:
417
+ distance = cv2.pointPolygonTest(orig_instance, (p[0,0].item(), p[0,1].item()), measureDist=True)
418
+ if distance >= 0:
419
+ final_points.append(p[0])
420
+ return np.stack(final_points)[:,None]
421
+
422
+ def get_border_functions(self, crop_points):
423
+ """ Creates and returns a function `fn` using crop region coordinates given in crop_points.
424
+ `fn` receives two input points x and xn and returns all the crossing points between the line connecting
425
+ x and xn, and the borders of the cropping rectangle.
426
+ """
427
+ p = crop_points[:,0]
428
+ p_next = np.roll(p, (-1), axis=(0))
429
+ def fn(x, xn):
430
+ output = []
431
+ c_diff = p_next - p
432
+ x_diff = x - xn
433
+ for diff, c in zip(c_diff, p):
434
+ A = np.array([
435
+ [diff[0], x_diff[0]],
436
+ [diff[1], x_diff[1]]
437
+ ])
438
+ b = x - c
439
+ try:
440
+ lmbda = np.linalg.solve(A, b)
441
+ if 0 <= lmbda[0] <= 1 and 0 <= lmbda[1] <= 1:
442
+ output.append(lmbda[1] * xn + (1-lmbda[1]) * x)
443
+ except np.linalg.LinAlgError: # singular system: the two lines are parallel, no crossing point
444
+ continue
445
+ return output
446
+ return fn
447
+
448
+ def crop_sample(self, sample, crop_coords):
449
+ """ Crop the sample using crop coordinates.
450
+ """
451
+ top, left, h, w = crop_coords
452
+ crop_region = (left, top, left + w, top + h)
453
+ crop_points = np.array([
454
+ [crop_region[0], crop_region[1]],
455
+ [crop_region[2], crop_region[1]],
456
+ [crop_region[2], crop_region[3]],
457
+ [crop_region[0], crop_region[3]],
458
+ ])[:,None]
459
+ border_functions = self.get_border_functions(crop_points)
460
+ cropped_sample = []
461
+ for instance in sample:
462
+ instance = self.extend_instance_points(instance, border_functions)
463
+ filter_condition = (
464
+ (instance[:, :, 0] > crop_region[0]) &
465
+ (instance[:, :, 0] < crop_region[2]) &
466
+ (instance[:, :, 1] > crop_region[1]) &
467
+ (instance[:, :, 1] < crop_region[3])
468
+ )
469
+ if not np.any(filter_condition):
470
+ continue
471
+
472
+ instance_copy = instance.copy()
473
+ instance_copy[:, :, 0] = np.clip(instance[:, :, 0], a_min=crop_region[0], a_max=crop_region[2])
474
+ instance_copy[:, :, 1] = np.clip(instance[:, :, 1], a_min=crop_region[1], a_max=crop_region[3])
475
+ instance_copy = self.remove_redundant_lines(instance, instance_copy)
476
+ instance_copy[:, :, 0] -= crop_region[0]
477
+ instance_copy[:, :, 1] -= crop_region[1]
478
+
479
+ cropped_sample.append(instance_copy)
480
+ return cropped_sample
481
+
482
+ def resize_sample(self, sample, original_size, target_size):
483
+ """ Resize the sample
484
+ """
485
+ width_scale = target_size[1] / original_size[1]
486
+ height_scale = target_size[0] / original_size[0]
487
+ resized_sample = []
488
+ for instance in sample:
489
+ instance_copy = instance.copy()
490
+ instance_copy[:, :, 0] = np.round(width_scale * instance_copy[:, :, 0])
491
+ instance_copy[:, :, 1] = np.round(height_scale * instance_copy[:, :, 1])
492
+ resized_sample.append(instance_copy)
493
+ return resized_sample
494
+
495
+ def remove_tiny_instances(self, sample, image_size):
496
+ """ Remove instances that have an area ratio smaller than `bbox_area_threshold`.
497
+ """
498
+ filtered_sample = []
499
+ for instance in sample:
500
+ min_h, min_w, max_h, max_w = self.get_bbox(instance)
501
+ bbox_area_ratio = (max_h - min_h) * (max_w - min_w) / (image_size[0] * image_size[1])
502
+ if bbox_area_ratio < self.bbox_area_threshold:
503
+ continue
504
+ filtered_sample.append(instance)
505
+ return filtered_sample
506
+
507
+ def hflip(self, sample, width):
508
+ """ Horizontal flipping the instances in a sample.
509
+ """
510
+ flipped_sample = []
511
+ for instance in sample:
512
+ instance_copy = instance.copy()
513
+ instance_copy[:, :, 0] = width - instance_copy[:, :, 0]
514
+ flipped_sample.append(instance_copy)
515
+ return flipped_sample
516
+
517
+ def get_binary_masks(self, sample):
518
+ """ Creates the binary mask of each instance in the sample.
519
+ """
520
+ if self.max_instance_n is None:
521
+ max_instance_n = len(sample)
522
+ else:
523
+ max_instance_n = self.max_instance_n
524
+ masks = np.zeros((max_instance_n, self.mask_size, self.mask_size))
525
+ bboxes = np.zeros((max_instance_n, 4))
526
+ valid = np.full(max_instance_n, False)
527
+ for i, instance in enumerate(sample):
528
+ bbox = self.get_bbox(instance)
529
+ min_h, min_w, max_h, max_w = bbox
530
+ instance_copy = instance.copy()
531
+ mask = np.zeros((self.mask_size, self.mask_size), dtype=np.uint8)
532
+ instance_copy[:,:,0] = (instance_copy[:,:,0] - min_w) / (max_w - min_w) * self.mask_size
533
+ instance_copy[:,:,1] = (instance_copy[:,:,1] - min_h) / (max_h - min_h) * self.mask_size
534
+ cv2.drawContours(mask, [instance_copy], 0, (255), thickness=cv2.FILLED)
535
+ masks[i] = mask / 255.0
536
+ bboxes[i] = np.array(bbox)
537
+ valid[i] = True
538
+ return masks, bboxes, valid
539
+
540
+ def load(self, path):
541
+ sample = np.load(path, allow_pickle=True)
542
+ return sample
543
+
544
+ def preprocess(self, sample):
545
+ if self.max_instance_n is None or len(sample) <= self.max_instance_n:
546
+ indices = np.arange(len(sample))
547
+ else:
548
+ indices = np.random.choice(len(sample), size=self.max_instance_n, replace=False)
549
+ return [p['points'] for i, p in enumerate(sample) if i in indices]
550
+
551
+ def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
552
+ rand_aug_idx: Optional[int], resample_mode: str = None):
553
+ v = self.crop_sample(v, crop_coords)
554
+ _, _, h, w = crop_coords
555
+ v = self.resize_sample(v, (h, w), target_size)
556
+ v = self.remove_tiny_instances(v, target_size)
557
+ if flip:
558
+ v = self.hflip(v, target_size[0])
559
+ return v
560
+
561
+ def postprocess(self, sample):
562
+ sample, bboxes, valid = self.get_binary_masks(sample)
563
+ return {
564
+ 'instance': torch.from_numpy(sample).to(torch.float32),
565
+ 'bbox': torch.from_numpy(bboxes).to(torch.float32),
566
+ 'valid': torch.from_numpy(valid)
567
+ }
568
+
569
+
570
+ class MaskTransform(ImageTransform):
571
+
572
+ def __init__(self, mask_pool_size=1):
573
+ assert isinstance(mask_pool_size, int)
574
+ self.mask_pool_size = mask_pool_size # Use to expand masks
575
+
576
+ def mask_to_tensor(self, img):
577
+ mask = TF.to_tensor(img)
578
+ if self.mask_pool_size > 1:
579
+ mask = reduce(mask, 'c (h1 h2) (w1 w2) -> c h1 w1', 'min', h2=self.mask_pool_size, w2=self.mask_pool_size)
580
+ mask = repeat(mask, 'c h1 w1 -> c (h1 h2) (w1 w2)', h2=self.mask_pool_size, w2=self.mask_pool_size)
581
+ return (mask == 1.0)
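Illustration of the pooling above (values are examples): with mask_pool_size=2, any 2x2 block containing a single invalid pixel becomes entirely invalid.

    import numpy as np
    from PIL import Image

    mt = MaskTransform(mask_pool_size=2)
    m = Image.fromarray(np.array([[255, 255], [255, 0]], dtype=np.uint8))  # one invalid pixel
    mt.mask_to_tensor(m)  # -> (1, 2, 2) boolean tensor, all False: the whole block is masked out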
582
+
583
+ def load(self, path):
584
+ sample = self.pil_loader(path)
585
+ return sample
586
+
587
+ def preprocess(self, sample):
588
+ return sample
589
+
590
+ def image_augment(self, img, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
591
+ rand_aug_idx: Optional[int], resample_mode: str = None):
592
+ # Override resampling mode to 'nearest' for masks
593
+ img = self.image_crop_and_resize(img, crop_coords, target_size, resample_mode='nearest')
594
+ img = self.image_hflip(img, flip)
595
+ return img
596
+
597
+ def postprocess(self, sample):
598
+ sample = self.mask_to_tensor(sample)
599
+ return sample
600
+
601
+
602
+ class TokTransform(AbstractTransform):
603
+
604
+ def __init__(self):
605
+ pass
606
+
607
+ def load(self, path):
608
+ sample = np.load(path).astype(int)
609
+ return sample
610
+
611
+ def preprocess(self, sample):
612
+ return sample
613
+
614
+ def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
615
+ rand_aug_idx: Optional[int], resample_mode: str = None):
616
+ if rand_aug_idx is None:
617
+ raise ValueError("Crop settings / augmentation index are missing but a pre-tokenized modality is being used")
618
+ v = torch.tensor(v[rand_aug_idx])
619
+ return v
620
+
621
+ def postprocess(self, sample):
622
+ return sample
623
+
624
+
625
+ class DetectionTransform(AbstractTransform):
626
+
627
+ def __init__(self, det_threshold=0.6, det_max_instances=None, bbox_order='dist_to_orig', coord_bins=1000, min_visibility=0.0, return_raw=False):
628
+ self.det_threshold = det_threshold
629
+ self.det_max_instances = det_max_instances
630
+ self.coord_bins = coord_bins
631
+ self.min_visibility = min_visibility
632
+ self.return_raw = return_raw
633
+
634
+ if bbox_order == 'area':
635
+ self.bbox_order = self.order_bboxes_by_area
636
+ elif bbox_order == 'score':
637
+ self.bbox_order = self.order_bboxes_by_score
638
+ elif bbox_order == 'random':
639
+ self.bbox_order = self.shuffle_bboxes
640
+ else:
641
+ self.bbox_order = self.order_bboxes_by_dist_to_orig
642
+
643
+ @staticmethod
644
+ def order_bboxes_by_area(bboxes):
645
+ return sorted(bboxes, key=lambda x: (x[2] - x[0]) * (x[3] - x[1]), reverse=True)
646
+
647
+ @staticmethod
648
+ def order_bboxes_by_dist_to_orig(bboxes):
649
+ return sorted(bboxes, key=lambda x: x[0] ** 2 + x[1] ** 2)
650
+
651
+ @staticmethod
652
+ def order_bboxes_by_score(bboxes):
653
+ return sorted(bboxes, key=lambda x: x[5], reverse=True)
654
+
655
+ @staticmethod
656
+ def shuffle_bboxes(bboxes):
657
+ return sorted(bboxes, key=lambda x: random.random())
658
+
659
+ def convert_detection_instance(self, instances):
660
+ """Convert instances dict to list of lists where each list takes the form:
661
+ [xmin, ymin, xmax, ymax, class_name, score]
662
+ """
663
+
664
+ instances = [inst['boxes'] + [inst['class_name'], inst['score']] for inst in instances if inst['score'] >= self.det_threshold]
665
+ return instances
666
+
667
+ def bboxes_hflip(self, bboxes: List[Tuple], image_size: Tuple, flip: bool):
668
+ image_height, image_width = image_size
669
+ if flip:
670
+ bboxes = [tuple(A.bbox_hflip(bbox[:4], rows=image_height, cols=image_width)) + tuple(bbox[4:])
671
+ for bbox in bboxes]
672
+
673
+ return bboxes
674
+
675
+ def bboxes_crop_and_resize(self, bboxes: List[Tuple], crop_coords: Tuple, orig_size: Tuple):
676
+ """Crop and resize bounding boxes
677
+
678
+ Args:
679
+ bboxes: Bounding boxes to crop and resize
680
+ crop_coords: Coordinates of the crop (top, left, h, w)
681
+ orig_size: Size of the original image
682
+
683
+ Returns:
684
+ Cropped and resized bounding boxes
685
+ """
686
+ orig_height, orig_width = orig_size
687
+ top, left, h, w = crop_coords
688
+ xmin, ymin, xmax, ymax = left, top, left + w, top + h
689
+ bboxes = [tuple(A.bbox_crop(bbox[:4], x_min=xmin, y_min=ymin, x_max=xmax, y_max=ymax, rows=orig_height,
690
+ cols=orig_width)) + tuple(bbox[4:])
691
+ for bbox in bboxes]
692
+ bboxes = A.core.bbox_utils.filter_bboxes(bboxes, rows=h, cols=w, min_visibility=self.min_visibility)
693
+ # No need to resize, bounding boxes in albumentations format are scale invariant
694
+
695
+ return bboxes
696
+
697
+ def order_and_filter_bboxes(self, bboxes):
698
+ if self.det_max_instances is not None and len(bboxes) > self.det_max_instances:
699
+ bboxes = self.order_bboxes_by_score(bboxes)[:self.det_max_instances]
700
+
701
+ return self.bbox_order(bboxes)
702
+
703
+ def convert_bboxes_to_string(self, bboxes: List[Tuple]):
704
+ """Convert bounding boxes to a string.
705
+ xmin, ymin, xmax, ymax are mapped to v0, v1, v2, v3 special tokens.
706
+
707
+ Args:
708
+ bboxes: Bounding boxes
709
+
710
+ Returns:
711
+ String representation of the bounding boxes
712
+ """
713
+ # Remove score, quantize coordinates
714
+ bins = self.coord_bins
715
+
716
+ bboxes = [
717
+ [
718
+ f"v0={round(xmin * (bins - 1))}",
719
+ f"v1={round(ymin * (bins - 1))}",
720
+ f"v2={round(xmax * (bins - 1))}",
721
+ f"v3={round(ymax * (bins - 1))}",
722
+ cls,
723
+ ]
724
+ for (xmin, ymin, xmax, ymax, cls, score) in bboxes
725
+ ]
726
+ # Convert each bounding box to a string
727
+ bboxes = [' '.join(b) for b in bboxes]
728
+ # Convert the list to a str
729
+ return ' '.join(bboxes)
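Example input/output for convert_bboxes_to_string (coordinates are already normalized to [0, 1] in albumentations format; the values are illustrative):

    det = DetectionTransform(coord_bins=1000)
    det.convert_bboxes_to_string([(0.1, 0.2, 0.6, 0.8, 'dog', 0.9)])
    # -> 'v0=100 v1=200 v2=599 v3=799 dog'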
730
+
731
+ def load(self, path):
732
+ with open(path, 'r') as f:
733
+ sample = json.load(f)
734
+
735
+ return sample
736
+
737
+ def preprocess(self, sample):
738
+ instances = sample['instances']
739
+ return self.convert_detection_instance(instances)
740
+
741
+ def image_augment(self, bboxes: List[Tuple], crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
742
+ rand_aug_idx=None, resample_mode: str = None):
743
+ bboxes = self.bboxes_crop_and_resize(bboxes, crop_coords, orig_size)
744
+ bboxes = self.bboxes_hflip(bboxes, target_size, flip)
745
+ bboxes = self.order_and_filter_bboxes(bboxes)
746
+ return bboxes
747
+
748
+ def postprocess(self, bboxes):
749
+ if self.return_raw:
750
+ return bboxes
751
+ bboxes = self.convert_bboxes_to_string(bboxes)
752
+ return bboxes
753
+
754
+
755
+ class CaptionTransform(AbstractTransform):
756
+
757
+ def __init__(self, aligned_captions=True, no_aug=False):
758
+ self.aligned_captions = aligned_captions
759
+ self.no_aug = no_aug
760
+
761
+ def load(self, path):
762
+ # Captions can be stored as .txt, .json, or .json.gz (the gzipped variant holds a list of dicts)
763
+ if path.endswith('.txt'):
764
+ sample = Path(path).read_text()
765
+ elif path.endswith('.json'):
766
+ with open(path, 'r') as f:
767
+ sample = json.load(f)
768
+ elif path.endswith('.json.gz'):
769
+ with gzip.open(path, 'rb') as f:
770
+ sample = json.load(f)
771
+ return sample
772
+
773
+ def preprocess(self, sample):
774
+ return sample
775
+
776
+ def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
777
+ rand_aug_idx: Optional[int], resample_mode: str = None):
778
+
779
+ if isinstance(val, list) or isinstance(val, tuple):
780
+ if self.aligned_captions:
781
+ val = val[0] if rand_aug_idx is None else val[rand_aug_idx]
782
+ else:
783
+ val = random.choice(val) if not self.no_aug else val[0]
784
+
785
+ if isinstance(val, dict):
786
+ # If each caption is saved as a dict, extract the string
787
+ val = val["caption"]
788
+ assert isinstance(val, str)
789
+
790
+ return val
791
+
792
+ def postprocess(self, sample):
793
+ return sample
794
+
795
+
796
+ class CaptionEmbTransform(AbstractTransform):
797
+
798
+ def __init__(self, aligned_captions=True, no_aug=False):
799
+ self.aligned_captions = aligned_captions
800
+ self.no_aug = no_aug
801
+
802
+ def load(self, path):
803
+ if path.endswith('.npz'):
804
+ sample = np.load(path)
805
+ sample = {'emb': sample['emb'], 'mask_valid': sample['mask_valid']}
806
+ else:
807
+ raise ValueError(f"Invalid file format for caption embedding: {path}")
808
+ return sample
809
+
810
+ def preprocess(self, sample):
811
+ return sample
812
+
813
+ def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
814
+ rand_aug_idx: Optional[int], resample_mode: str = None):
815
+
816
+ emb = val['emb']
817
+ mask_valid = val['mask_valid'].astype(bool)
818
+ num_sequences = emb.shape[0]
819
+
820
+ if num_sequences > 1:
821
+ if self.aligned_captions:
822
+ if rand_aug_idx is None:
823
+ emb, mask_valid = emb[0], mask_valid[0]
824
+ else:
825
+ emb, mask_valid = emb[rand_aug_idx], mask_valid[rand_aug_idx]
826
+ else:
827
+ if self.no_aug:
828
+ emb, mask_valid = emb[0], mask_valid[0]
829
+ else:
830
+ rand_idx = random.randint(0, num_sequences - 1)
831
+ emb, mask_valid = emb[rand_idx], mask_valid[rand_idx]
832
+ else:
833
+ emb, mask_valid = emb[0], mask_valid[0]
834
+
835
+ emb = emb[mask_valid] # Keep only valid embeddings
836
+
837
+ return emb
838
+
839
+ def postprocess(self, sample):
840
+ return torch.tensor(sample)
841
+
842
+
843
+ class MetadataTransform(AbstractTransform):
844
+
845
+ def __init__(self,
846
+ special_vmin: int = 0,
847
+ special_vmax: int = 999,
848
+ shuffle: bool = True,
849
+ random_trunc: bool = False,
850
+ return_chunks: bool = True,
851
+ return_raw: bool = False,
852
+ image_dim_bin_size: int = 32,):
853
+ """Metadata transform that takes in a metadata dictionary and converts
854
+ it into a string, or list of strings (for chunked span masking).
855
+ Uses special tokens v1 to denote metadata types, and v0 for their values.
856
+
857
+ Args:
858
+ special_vmin: Minimum value for special tokens
859
+ special_vmax: Maximum value for special tokens
860
+ shuffle: Whether to shuffle the metadata order
861
+ random_trunc: Whether to randomly truncate the returned metadata
862
+ return_chunks: Whether to return a list of strings (for chunked span masking),
863
+ or a single string with all metadata concatenated
864
+ return_raw: Whether to return the raw metadata dictionary
865
+ """
866
+ self.special_vmin = special_vmin
867
+ self.special_vmax = special_vmax
868
+ self.shuffle = shuffle
869
+ self.random_trunc = random_trunc
870
+ self.return_chunks = return_chunks
871
+ self.return_raw = return_raw
872
+ self.image_dim_bin_size = image_dim_bin_size
873
+
874
+ # Explicit map to make sure that additional entries do not change existing IDs
875
+ # TODO: Make this work with other text tokenizers
876
+ self.metadata_id_map = {
877
+ 'original_width': 'v1=0',
878
+ 'original_height': 'v1=1',
879
+ 'caption_n_chars': 'v1=2',
880
+ 'caption_n_words': 'v1=3',
881
+ 'caption_n_sentences': 'v1=4',
882
+ 'n_humans': 'v1=5',
883
+ 'n_sam_instances': 'v1=6',
884
+ 'n_coco_instances': 'v1=7',
885
+ 'coco_instance_diversity': 'v1=8',
886
+ 'colorfulness': 'v1=9',
887
+ 'brightness': 'v1=10',
888
+ 'contrast': 'v1=11',
889
+ 'saturation': 'v1=12',
890
+ 'entropy': 'v1=13',
891
+ 'walkability': 'v1=14',
892
+ 'objectness': 'v1=15',
893
+ 'semantic_diversity': 'v1=16',
894
+ 'geometric_complexity': 'v1=17',
895
+ 'occlusion_score': 'v1=18',
896
+ 'watermark_score': 'v1=19',
897
+ 'aesthetic_score': 'v1=20',
898
+ }
899
+ self.id_metadata_map = {v: k for k, v in self.metadata_id_map.items()}
900
+
901
+ # Image-dimension modalities are binned into 32 bins
902
+ self.image_dim_modalities = ['original_height', 'original_width']
903
+
904
+ # Integer modalities that don't undergo any scaling (except for truncation)
905
+ self.metadata_int_modalities = [
906
+ 'caption_n_chars', 'caption_n_words', 'caption_n_sentences',
907
+ 'n_humans', 'n_sam_instances', 'n_coco_instances',
908
+ 'coco_instance_diversity', 'semantic_diversity',
909
+ ]
910
+
911
+ # Bin boundaries for manually defined metadata modalities.
912
+ # Lowest and highest bin boundaries are implicitly set to -inf and +inf
913
+ self.metadata_manual_bins = {
914
+ 'watermark_score': [0.5],
915
+ 'aesthetic_score': [4.5, 5.5],
916
+ }
917
+
918
+ # All other float or integer modalities that are binned into a defined number of bins
919
+ # Dictionary entries are (vmin, vmax, num_bins)
920
+ self.metadata_min_max_bins = {
921
+ 'colorfulness': (0, 150, 50),
922
+ 'brightness': (0, 255, 50),
923
+ 'contrast': (0, 127, 50),
924
+ 'saturation': (0, 255, 50),
925
+ 'entropy': (0, 10, 50),
926
+ 'walkability': (0, 1, 50),
927
+ 'objectness': (0, 1, 50),
928
+ 'geometric_complexity': (0, 0.75, 50),
929
+ 'occlusion_score': (0, 0.25, 50),
930
+ }
931
+
932
+ def image_dim_to_string(self, metadata, key, bin_size=32):
933
+ value = metadata[key] // bin_size
934
+ value = max(self.special_vmin, min(value, self.special_vmax))
935
+ return f"{self.metadata_id_map[key]} v0={value}"
936
+
937
+ def int_metadata_to_string(self, metadata, key):
938
+ value = max(self.special_vmin, min(metadata[key], self.special_vmax))
939
+ return f"{self.metadata_id_map[key]} v0={value}"
940
+
941
+ def float_metadata_to_string(self, metadata, key, vmin, vmax, bins):
942
+ value = max(vmin, min(metadata[key], vmax))
943
+ value = (value - vmin) / (vmax - vmin)
944
+ value = int(value * (bins-1))
945
+ return f"{self.metadata_id_map[key]} v0={value}"
946
+
947
+ def manual_bin_metadata_to_string(self, metadata, key):
948
+ value = metadata[key]
949
+ bin_idx = 0
950
+ for bin_value in self.metadata_manual_bins[key]:
951
+ if value < bin_value:
952
+ break
953
+ bin_idx += 1
954
+ return f"{self.metadata_id_map[key]} v0={bin_idx}"
955
+
956
+ def metadata_to_string(self, metadata, keys: List[str] = None):
957
+ keys = list(metadata.keys()) if keys is None else keys
958
+
959
+ if self.shuffle:
960
+ # Randomly shuffle
961
+ random.shuffle(keys)
962
+ if self.random_trunc:
963
+ # Randomly truncate
964
+ keys = keys[:random.randint(1,len(keys))]
965
+
966
+ metadata_strings = []
967
+
968
+ for key in keys:
969
+ if key in self.image_dim_modalities:
970
+ # Image dimension modalities
971
+ metadata_str = self.image_dim_to_string(metadata, key, bin_size=self.image_dim_bin_size)
972
+ elif key in self.metadata_int_modalities:
973
+ # Integer modalities that don't undergo any scaling
974
+ metadata_str = self.int_metadata_to_string(metadata, key)
975
+ elif key in self.metadata_manual_bins:
976
+ # Metadata modalities for which bin boundaries are manually defined
977
+ metadata_str = self.manual_bin_metadata_to_string(metadata, key)
978
+ else:
979
+ # All other modalities
980
+ vmin, vmax, bins = self.metadata_min_max_bins[key]
981
+ metadata_str = self.float_metadata_to_string(metadata, key, vmin, vmax, bins)
982
+
983
+ metadata_strings.append(metadata_str)
984
+
985
+ if self.return_chunks:
986
+ return metadata_strings
987
+ else:
988
+ return ' '.join(metadata_strings)
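Illustrative call to metadata_to_string, with shuffling disabled so the output is deterministic:

    mt = MetadataTransform(shuffle=False, random_trunc=False, return_chunks=False)
    mt.metadata_to_string({'original_width': 640, 'n_humans': 2})
    # -> 'v1=0 v0=20 v1=5 v0=2'   (640 // 32 = 20; integer counts are passed through, only clamped)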
989
+
990
+ def load(self, path):
991
+ with open(path, 'r') as f:
992
+ sample = json.load(f)
993
+ return sample
994
+
995
+ def preprocess(self, sample):
996
+ return sample
997
+
998
+ def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
999
+ rand_aug_idx=None, resample_mode: str = None):
1000
+ return val
1001
+
1002
+ def postprocess(self, metadata):
1003
+ if self.return_raw:
1004
+ return metadata
1005
+ metadata_str = self.metadata_to_string(metadata)
1006
+ return metadata_str
1007
+
1008
+
1009
+ class HumanPoseTransform(AbstractTransform):
1010
+
1011
+ def __init__(self, coord_bins=1000, only_pose=False, return_raw=False):
1012
+ self.coord_bins = coord_bins
1013
+ self.return_raw = return_raw
1014
+ self.only_pose = only_pose
1015
+
1016
+ def convert_humanpose_instance(self, instances, only_pose=False):
1017
+ """Convert instances dict to list of lists where each list takes the form:
1018
+ [human, xmin xmax ymin ymax global val1 val2 ... val10 pose val1 val2 ... val 207 shape val1 val2 ... val10 camera val1 val2 val3 val4]
1019
+ Like for bounding boxes, xmin, ymin, xmax, and ymax map to v0, v1, v2, and v3 respectively.
1020
+ """
1021
+ if only_pose: # used for tokenizer training for pose
1022
+ if len(instances) == 0:
1023
+ return torch.zeros(207)
1024
+ else:
1025
+ return torch.from_numpy(np.array(instances['pred_smpl_params']['body_pose'][0]).flatten()).float()
1026
+ if len(instances) == 0: #empty, i.e. there are no humans
1027
+ return 'none'
1028
+
1029
+ for k in instances:
1030
+ if k!='pred_smpl_params':
1031
+ instances[k] = torch.from_numpy(np.array(instances[k]))
1032
+
1033
+ smpl_params = (instances['pred_smpl_params'])
1034
+
1035
+ for k in smpl_params:
1036
+ smpl_params[k] = torch.from_numpy(np.array(smpl_params[k]))
1037
+
1038
+ total_num_instances = len(instances['bbox_xyxy'])
1039
+ instances_converted = []
1040
+ for ii in range(total_num_instances):
1041
+ instances_converted.append(['human'] + (np.array(instances['bbox_xyxy'][ii]).flatten().tolist()) + ['global'] + (np.array(instances['pred_smpl_params']['global_orient'][ii]).flatten().tolist()) + ['pose'] + (instances['pose_tokenized'][ii].flatten().tolist()) + ['shape'] + (instances['pred_smpl_params']['betas'][ii].flatten().tolist()) + ['camera'] + (instances['pred_cam'][ii].flatten().tolist()))
1042
+ return instances_converted
1043
+
1044
+ def humanposes_crop_and_resize(self, humanposes: List[Tuple], crop_coords: Tuple, orig_size: Tuple,):
1045
+ """Crop and resize human poses (and their bounding boxes)
1046
+ """
1047
+ orig_height, orig_width = orig_size
1048
+ top, left, h, w = crop_coords
1049
+
1050
+ humanposes_converted_resized = []
1051
+ for instance in humanposes:
1052
+ bbox_curr = instance[1:5]
1053
+ bbox_curr = np.array(bbox_curr)
1054
+ bbox_curr[0::2] = bbox_curr[0::2] / orig_width
1055
+ bbox_curr[1::2] = bbox_curr[1::2] / orig_height
1056
+
1057
+ xmin, ymin, xmax, ymax = left, top, left + w, top + h
1058
+ bbox_curr = A.bbox_crop(bbox_curr, x_min=xmin, y_min=ymin, x_max=xmax, y_max=ymax, rows=orig_height,
1059
+ cols=orig_width)
1060
+ bbox_curr = np.array(bbox_curr)
1061
+ if np.all(bbox_curr[1::2]<0) or np.all(bbox_curr[0::2]<0): #bbox is out of range, remove it
1062
+ continue
1063
+ if np.all(bbox_curr[1::2]>1.0) or np.all(bbox_curr[0::2]>1.0): #bbox is out of range, remove it
1064
+ continue
1065
+ bbox_curr = np.clip(bbox_curr, a_min=0, a_max=1.)
1066
+
1067
+ instance[1:5] = bbox_curr
1068
+ humanposes_converted_resized.append(instance)
1069
+
1070
+ # now return all instances, or none if there is no instance
1071
+ if len(humanposes_converted_resized)>0:
1072
+ pass
1073
+ else: # no valid poses remain
1074
+ return 'none'
1075
+
1076
+ humanpose_returned = humanposes_converted_resized
1077
+
1078
+ return humanpose_returned
1079
+
1080
+ def convert_humanposes_to_string(self, all_humanposes: List[Tuple]):
1081
+ """Convert humanposes to a string
1082
+ range of global orientation: [-1, 1]
1083
+ range of object pose: [-1, 1]
1084
+ range of shape (betas): [-3, 3]
1085
+ range of camera: [-1, 19]
1086
+ """
1087
+ bins = self.coord_bins
1088
+
1089
+ instance_final_all = ''
1090
+
1091
+ for humanposes in all_humanposes:
1092
+ human = humanposes[0]
1093
+ bboxes = humanposes[1:5]
1094
+ glob = humanposes[5]
1095
+ global_orient = np.array(humanposes[6:15])
1096
+ pose = humanposes[15]
1097
+ pose_params = np.array(humanposes[16:24])
1098
+ shape = humanposes[24]
1099
+ shape_params = np.array(humanposes[25:35])
1100
+ camera = humanposes[35]
1101
+ camera_params = np.clip(np.array(humanposes[36:]), a_min=-1., a_max=19.)
1102
+
1103
+ bboxes_new = [
1104
+ f"v0={round(bboxes[0] * (bins - 1))}",
1105
+ f"v1={round(bboxes[1] * (bins - 1))}",
1106
+ f"v2={round(bboxes[2] * (bins - 1))}",
1107
+ f"v3={round(bboxes[3] * (bins - 1))}"]
1108
+
1109
+ global_orient = 499.5*global_orient
1110
+ global_orient_new = []
1111
+ for ii in range(len(global_orient)):
1112
+ global_orient_curr = f"v0={round(global_orient[ii]+499.5)}"
1113
+ global_orient_new.append(global_orient_curr)
1114
+
1115
+ pose_params_new = []
1116
+ for ii in range(len(pose_params)):
1117
+ if pose_params[ii]<512:
1118
+ pose_params_curr = f"v0={round(pose_params[ii])}"
1119
+ else:
1120
+ pose_params_curr = f"v1={round(pose_params[ii] - 512)}"
1121
+ pose_params_new.append(pose_params_curr)
1122
+
1123
+ shape_params = 166.5*shape_params
1124
+ shape_params_new = []
1125
+ for ii in range(len(shape_params)):
1126
+ shape_params_curr = f"v0={round(shape_params[ii]+499.5)}"
1127
+ shape_params_new.append(shape_params_curr)
1128
+
1129
+ camera_params = 49.95*camera_params
1130
+ camera_params_new = []
1131
+ for ii in range(len(camera_params)):
1132
+ camera_params_curr = f"v0={round(camera_params[ii]+49.95)}"
1133
+ camera_params_new.append(camera_params_curr)
1134
+
1135
+ #randomly shuffle everything except bbox part of the sequence
1136
+ all_strings = [[pose]+pose_params_new, [glob] + global_orient_new, [camera] + camera_params_new, [shape] + shape_params_new ]
1137
+ rand_perm = torch.randperm(4)
1138
+ instance_final = [human] + bboxes_new + all_strings[rand_perm[0]] + all_strings[rand_perm[1]] + all_strings[rand_perm[2]] + all_strings[rand_perm[3]]
1139
+
1140
+
1141
+ instance_final = ', '.join(instance_final)
1142
+ instance_final = instance_final.replace(",", "")
1143
+ instance_final_all = instance_final_all + instance_final + ' '
1144
+
1145
+ return instance_final_all
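For reference, the quantization used in the conversion above maps each parameter group into the 0-999 special-token range (summarized from the code and the docstring ranges; illustrative):

    # global orientation g in [-1, 1]   ->  v0 = round(499.5 * g + 499.5)   in [0, 999]
    # shape betas        b in [-3, 3]   ->  v0 = round(166.5 * b + 499.5)   in [0, 999]
    # camera             c in [-1, 19]  ->  v0 = round(49.95 * (c + 1))     in [0, 999]
    # pose tokens        p in [0, 1023] ->  'v0=p' if p < 512 else 'v1=p-512'
    # bbox coordinates   x in [0, 1]    ->  v0..v3 = round(x * (coord_bins - 1))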
1146
+
1147
+ def load(self, path):
1148
+ with open(path, 'r') as f:
1149
+ sample = json.load(f)
1150
+
1151
+ return sample
1152
+
1153
+ def preprocess(self, sample):
1154
+ instances = sample
1155
+ instances = self.convert_humanpose_instance(instances, only_pose=self.only_pose)
1156
+ return instances
1157
+
1158
+ def image_augment(self, humanposes: List[Tuple], crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
1159
+ rand_aug_idx=None, resample_mode: str = None):
1160
+ if humanposes=='none' or self.only_pose:
1161
+ return humanposes
1162
+ humanposes = self.humanposes_crop_and_resize(humanposes, crop_coords, orig_size)
1163
+ return humanposes
1164
+
1165
+ def postprocess(self, humanposes):
1166
+ if humanposes=='none' or self.only_pose:
1167
+ return humanposes if not self.return_raw else []
1168
+ if self.return_raw:
1169
+ return humanposes
1170
+ humanposes = self.convert_humanposes_to_string(humanposes)
1171
+ return humanposes
1172
+
1173
+
1174
+ class ColorPaletteTransform(AbstractTransform):
1175
+
1176
+ def __init__(self, coord_bins=1000, return_raw=False):
1177
+ self.coord_bins = coord_bins
1178
+ self.return_raw = return_raw
1179
+
1180
+ def convert_palette_instance(self, instances):
1181
+ """Convert colors to v0= v0= ...
1182
+ """
1183
+ length = random.randint(1,7)
1184
+ instances_converted = np.array(instances[0][str(length)]).flatten().tolist()
1185
+ return instances_converted
1186
+
1187
+ def palette_hflip(self, palettes: List[Tuple], image_size: Tuple, flip: bool):
1188
+
1189
+ return palettes
1190
+
1191
+ def convert_palettes_to_string(self, all_palettes: List[Tuple]):
1192
+ """Convert palettes to a string
1193
+ """
1194
+
1195
+ colors = []
1196
+ len_palettes = len(all_palettes)
1197
+ colors.append(f"v1={round(len_palettes/3)}") # start with the length of the color palette to avoid confusion
1198
+ for ii in range(len(all_palettes)):
1199
+ color_new = f"v0={round(all_palettes[ii])}"
1200
+ colors.append(color_new)
1201
+
1202
+ instance_final_all = colors
1203
+ instance_final_all = ', '.join(instance_final_all)
1204
+ instance_final_all = instance_final_all.replace(",", "")
1205
+
1206
+ return instance_final_all
1207
+
1208
+ def load(self, path):
1209
+ with open(path, 'r') as f:
1210
+ sample = json.load(f)
1211
+ return sample
1212
+
1213
+ def preprocess(self, sample):
1214
+ if self.return_raw:
1215
+ return sample
1216
+ instances = sample
1217
+ instances = self.convert_palette_instance(instances)
1218
+ return instances
1219
+
1220
+ def image_augment(self, palettes: List[Tuple], crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
1221
+ rand_aug_idx=None, resample_mode: str = None):
1222
+ return palettes
1223
+
1224
+ def postprocess(self, palettes):
1225
+ if self.return_raw:
1226
+ return palettes
1227
+ palettes = self.convert_palettes_to_string(palettes)
1228
+ return palettes
1229
+
1230
+
1231
+ class SAMInstanceTokTransform(AbstractTransform):
1232
+
1233
+ def __init__(self, image_size=224, points_per_side=7, point_order='random'):
1234
+ self.H, self.W = to_2tuple(image_size)
1235
+ self.points_per_h, self.points_per_w = to_2tuple(points_per_side)
1236
+ assert point_order in ['random', 'grid']
1237
+ self.point_order = point_order
1238
+
1239
+ def get_query_points(self):
1240
+ if self.point_order == 'grid':
1241
+ # Create and cache grid query points
1242
+ if not hasattr(self, 'grid_query_points'):
1243
+ y, x = np.meshgrid(np.linspace(0, self.H, self.points_per_h + 2)[1:-1], np.linspace(0, self.W, self.points_per_w + 2)[1:-1])
1244
+ grid = np.stack((x, y), axis=2).astype(np.int32)
1245
+ self.grid_query_points = grid.reshape(-1, 2)
1246
+ return self.grid_query_points
1247
+ elif self.point_order == 'random':
1248
+ # Randomly sample query points
1249
+ y = np.random.randint(0, self.H, self.points_per_h)
1250
+ x = np.random.randint(0, self.W, self.points_per_w)
1251
+ return np.concatenate((x[:,None], y[:,None]), axis=1)
1252
+ else:
1253
+ raise ValueError(f"Query point order mode {self.point_order} is not supported.")
1254
+
1255
+ def get_target_tokens(self, sample, query_points):
1256
+ instances_coords = [coords[0] for coords in sample['points']]
1257
+ tokens = sample['token_ids']
1258
+ bboxes = sample['bbox']
1259
+
1260
+ instance_tokens_per_qpoint = dict()
1261
+ for point in query_points:
1262
+ point = (int(point[0].item()), int(point[1].item()))
1263
+ instance_tokens_per_qpoint[point] = []
1264
+ for i, (coords, tok, bbox) in enumerate(zip(instances_coords, tokens, bboxes)):
1265
+ # Calculate the distance from the query point to the instance
1266
+ distance = cv2.pointPolygonTest(coords, point, measureDist=True)
1267
+ # If the query point is inside the instance, add its corresponding token
1268
+ if distance >= 0:
1269
+ instance_tokens_per_qpoint[point].append((tok, bbox))
1270
+
1271
+ return instance_tokens_per_qpoint
1272
+
1273
+ def convert_target_tokens_to_string(self, target_tokens):
1274
+ result_text = []
1275
+ query_points = list(target_tokens.keys())
1276
+ # Randomly shuffle query points order (mainly for grid order)
1277
+ random.shuffle(query_points)
1278
+ for point in query_points:
1279
+
1280
+ # Add query point coordinates to the string
1281
+ result_text.append('point')
1282
+ result_text.append(f'v0={point[1]}')
1283
+ result_text.append(f'v1={point[0]}')
1284
+
1285
+ # Randomly shuffle the order of instance tokens per query point
1286
+ random.shuffle(target_tokens[point])
1287
+ if len(target_tokens[point]) == 0:
1288
+                # If no instance tokens are found, add 'none' to the string
1289
+ result_text.append('none')
1290
+ else:
1291
+ for tok, bbox in target_tokens[point]:
1292
+ result_text.append(f'polygon')
1293
+
1294
+ # Add bounding box coordinates to the string
1295
+ ymin, xmin, ymax, xmax = bbox.astype(np.int32)
1296
+ result_text.extend([
1297
+ f'v0={xmin}',
1298
+ f'v1={ymin}',
1299
+ f'v2={xmax}',
1300
+ f'v3={ymax}',
1301
+ ])
1302
+
1303
+ # Add instance tokens ids to the string
1304
+ for idx in tok.tolist():
1305
+ if idx < 512:
1306
+ result_text.append(f'v0={idx}')
1307
+ else:
1308
+ result_text.append(f'v1={idx - 512}')
1309
+
1310
+ return " ".join(result_text)
1311
+
1312
+ def load(self, path):
1313
+ sample = np.load(path, allow_pickle=True)
1314
+ return sample
1315
+
1316
+ def preprocess(self, sample):
1317
+ for s in sample:
1318
+ s['token_ids'] = s['token_ids'].astype(np.int32)
1319
+ return sample
1320
+
1321
+ def image_augment(self, v, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
1322
+ rand_aug_idx: Optional[int], resample_mode: str = None):
1323
+ if rand_aug_idx is None:
1324
+ raise ValueError("Crop settings / augmentation index are missing but a pre-tokenized modality is being used")
1325
+ v = v[rand_aug_idx]
1326
+ return v
1327
+
1328
+ def postprocess(self, sample):
1329
+ query_points = self.get_query_points()
1330
+ target_tokens = self.get_target_tokens(sample, query_points)
1331
+ final_string = self.convert_target_tokens_to_string(target_tokens)
1332
+ return final_string
1333
+
1334
+
1335
+ class CropSettingsTransform(AbstractTransform):
1336
+
1337
+ def load(self, path):
1338
+ sample = np.load(path)
1339
+ return sample
1340
+
1341
+ def preprocess(self, sample):
1342
+ raise NotImplementedError("CropSettingsTransform does not support preprocessing")
1343
+
1344
+ def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
1345
+ rand_aug_idx: Optional[int], resample_mode: str = None):
1346
+ raise NotImplementedError("CropSettingsTransform is not meant to be used for image augmentation")
1347
+
1348
+ def postprocess(self, sample):
1349
+ raise NotImplementedError("CropSettingsTransform does not support postprocessing")
1350
+
1351
+
1352
+ class IdentityTransform(AbstractTransform):
1353
+
1354
+ def load(self, path):
1355
+ raise NotImplementedError("IdentityTransform does not support loading")
1356
+
1357
+ def preprocess(self, sample):
1358
+ return sample
1359
+
1360
+ def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
1361
+ rand_aug_idx: Optional[int], resample_mode: str = None):
1362
+ return val
1363
+
1364
+ def postprocess(self, sample):
1365
+ return sample
1366
+
1367
+
1368
+ class JSONTransform(AbstractTransform):
1369
+
1370
+ def load(self, path):
1371
+ if path.endswith('.json'):
1372
+ with open(path, 'r') as f:
1373
+ sample = json.load(f)
1374
+ elif path.endswith('.json.gz'):
1375
+ with gzip.open(path, 'rb') as f:
1376
+ sample = json.load(f)
1377
+ return sample
1378
+
1379
+ def preprocess(self, sample):
1380
+ return sample
1381
+
1382
+ def image_augment(self, val, crop_coords: Tuple, flip: bool, orig_size: Tuple, target_size: Tuple,
1383
+ rand_aug_idx: Optional[int], resample_mode: str = None):
1384
+ return val
1385
+
1386
+ def postprocess(self, sample):
1387
+ return sample
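
The transforms added above all follow the same four-step AbstractTransform interface: load, preprocess, image_augment, and postprocess. A minimal sketch of driving one of them by hand (illustrative, not part of this commit; the metadata.json path and crop arguments are made up, and JSONTransform simply passes the augmentation parameters through):

from fourm.data.modality_transforms import JSONTransform

tfm = JSONTransform()
sample = tfm.load('metadata.json')       # read the raw JSON from disk
sample = tfm.preprocess(sample)          # no-op for JSONTransform
# The image augmenter normally supplies shared crop/flip settings for all modalities;
# JSONTransform ignores them and returns the value unchanged.
sample = tfm.image_augment(sample, crop_coords=(0, 0, 224, 224), flip=False,
                           orig_size=(256, 256), target_size=(224, 224),
                           rand_aug_idx=None)
result = tfm.postprocess(sample)         # returns the JSON object unchanged
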
fourm/data/multimodal_dataset_folder.py ADDED
@@ -0,0 +1,363 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import os.path
16
+ import pickle
17
+ import random
18
+ from copy import deepcopy
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, cast
20
+
21
+ import numpy as np
22
+ from torchvision.datasets.vision import VisionDataset
23
+
24
+ from fourm.data.modality_transforms import AbstractTransform, get_transform_key
25
+
26
+ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp', '.jpx', '.npy', '.npz')
27
+
28
+ UNIFIED_EXTENSIONS = IMG_EXTENSIONS + ('.json', '.txt', '.json.gz')
29
+
30
+
31
+ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
32
+ """Checks if a file is an allowed extension.
33
+
34
+ Args:
35
+ filename (string): path to a file
36
+ extensions (tuple of strings): extensions to consider (lowercase)
37
+
38
+ Returns:
39
+ bool: True if the filename ends with one of given extensions
40
+ """
41
+ return filename.lower().endswith(extensions)
42
+
43
+
44
+ def is_image_file(filename: str) -> bool:
45
+ """Checks if a file is an allowed image extension.
46
+
47
+ Args:
48
+ filename (string): path to a file
49
+
50
+ Returns:
51
+ bool: True if the filename ends with a known image extension
52
+ """
53
+ return has_file_allowed_extension(filename, IMG_EXTENSIONS)
54
+
55
+
56
+ def make_dataset(
57
+ directory: str,
58
+ class_to_idx: Dict[str, int],
59
+ extensions: Optional[Tuple[str, ...]] = None,
60
+ is_valid_file: Optional[Callable[[str], bool]] = None,
61
+ cache_path: Optional[str] = None,
62
+ ) -> List[Tuple[str, int]]:
63
+ if cache_path is not None and os.path.exists(cache_path):
64
+ # Load cached file paths from disk if it exists
65
+ with open(cache_path, 'rb') as f:
66
+ return pickle.load(f)
67
+ instances = []
68
+ directory = os.path.expanduser(directory)
69
+ both_none = extensions is None and is_valid_file is None
70
+ both_something = extensions is not None and is_valid_file is not None
71
+ if both_none or both_something:
72
+ raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
73
+ if extensions is not None:
74
+ def is_valid_file(x: str) -> bool:
75
+ return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
76
+ is_valid_file = cast(Callable[[str], bool], is_valid_file)
77
+ for target_class in sorted(class_to_idx.keys()):
78
+ class_index = class_to_idx[target_class]
79
+ target_dir = os.path.join(directory, target_class)
80
+ if not os.path.isdir(target_dir):
81
+ continue
82
+ for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
83
+ for fname in sorted(fnames):
84
+ path = os.path.join(root, fname)
85
+ if is_valid_file(path):
86
+ item = path, class_index
87
+ instances.append(item)
88
+ if cache_path is not None:
89
+ # Cache all file paths s.t. setting up the dataloader is instant in the future
90
+ os.makedirs(os.path.dirname(cache_path), exist_ok=True)
91
+ with open(cache_path, 'wb') as f:
92
+ pickle.dump(instances, f)
93
+ return instances
94
+
95
+
96
+ class DatasetFolder(VisionDataset):
97
+ """A generic data loader where the samples are arranged in this way: ::
98
+
99
+ root/class_x/xxx.ext
100
+ root/class_x/xxy.ext
101
+ root/class_x/xxz.ext
102
+
103
+ root/class_y/123.ext
104
+ root/class_y/nsdf3.ext
105
+ root/class_y/asd932_.ext
106
+
107
+ Args:
108
+ root (string): Root directory path.
109
+ loader (callable): A function to load a sample given its path.
110
+ extensions (tuple[string]): A list of allowed extensions.
111
+ both extensions and is_valid_file should not be passed.
112
+ transform (callable, optional): A function/transform that takes in
113
+ a sample and returns a transformed version.
114
+ E.g, ``transforms.RandomCrop`` for images.
115
+ target_transform (callable, optional): A function/transform that takes
116
+ in the target and transforms it.
117
+ is_valid_file (callable, optional): A function that takes path of a file
118
+            and checks if the file is a valid file (used to check for corrupt files)
119
+ both extensions and is_valid_file should not be passed.
120
+
121
+ Attributes:
122
+ classes (list): List of the class names sorted alphabetically.
123
+ class_to_idx (dict): Dict with items (class_name, class_index).
124
+ samples (list): List of (sample path, class_index) tuples
125
+ targets (list): The class_index value for each image in the dataset
126
+ """
127
+
128
+ def __init__(
129
+ self,
130
+ root: str,
131
+ loader: Callable[[str], Any],
132
+ extensions: Optional[Tuple[str, ...]] = None,
133
+ transform: Optional[Callable] = None,
134
+ target_transform: Optional[Callable] = None,
135
+ is_valid_file: Optional[Callable[[str], bool]] = None,
136
+ ) -> None:
137
+ super(DatasetFolder, self).__init__(root, transform=transform,
138
+ target_transform=target_transform)
139
+ classes, class_to_idx = self._find_classes(self.root)
140
+ samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
141
+ if len(samples) == 0:
142
+            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
143
+ if extensions is not None:
144
+ msg += "Supported extensions are: {}".format(",".join(extensions))
145
+ raise RuntimeError(msg)
146
+
147
+ self.loader = loader
148
+ self.extensions = extensions
149
+
150
+ self.classes = classes
151
+ self.class_to_idx = class_to_idx
152
+ self.samples = samples
153
+ self.targets = [s[1] for s in samples]
154
+
155
+ def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
156
+ """
157
+ Finds the class folders in a dataset.
158
+
159
+ Args:
160
+ dir (string): Root directory path.
161
+
162
+ Returns:
163
+ tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
164
+
165
+ Ensures:
166
+ No class is a subdirectory of another.
167
+ """
168
+ classes = [d.name for d in os.scandir(dir) if d.is_dir()]
169
+ classes.sort()
170
+ class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
171
+ return classes, class_to_idx
172
+
173
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
174
+ """
175
+ Args:
176
+ index (int): Index
177
+
178
+ Returns:
179
+ tuple: (sample, target) where target is class_index of the target class.
180
+ """
181
+ while True:
182
+ try:
183
+ path, target = self.samples[index]
184
+ sample = self.loader(path)
185
+ break
186
+ except Exception as e:
187
+ print(e)
188
+ index = random.randint(0, len(self.samples) - 1)
189
+
190
+ if self.transform is not None:
191
+ sample = self.transform(sample)
192
+ if self.target_transform is not None:
193
+ target = self.target_transform(target)
194
+
195
+ return sample, target
196
+
197
+ def __len__(self) -> int:
198
+ return len(self.samples)
199
+
200
+
201
+ class MultiModalDatasetFolder(VisionDataset):
202
+ """A generic multi-modal dataset loader where the samples are arranged in this way: ::
203
+
204
+ root/modality_a/class_x/xxx.ext
205
+ root/modality_a/class_y/xxy.ext
206
+ root/modality_a/class_z/xxz.ext
207
+
208
+ root/modality_b/class_x/xxx.ext
209
+ root/modality_b/class_y/xxy.ext
210
+ root/modality_b/class_z/xxz.ext
211
+
212
+ Args:
213
+ root (string): Root directory path.
214
+ modalities (list): List of modalities as strings
215
+ modality_paths (dict): Dict of paths to modalities
216
+ modality_transforms (dict): Dict of transforms for each modality
217
+ loader (callable): A function to load a sample given its path.
218
+ transform (callable, optional): A function/transform that takes in
219
+ a sample and returns a transformed version.
220
+ E.g, ``transforms.RandomCrop`` for images.
221
+ target_transform (callable, optional): A function/transform that takes
222
+ in the target and transforms it.
223
+ is_valid_file (callable, optional): A function that takes path of a file
224
+            and checks if the file is a valid file (used to check for corrupt files)
225
+ both extensions and is_valid_file should not be passed.
226
+ max_samples (int, optional): Maximum number of samples to load. If None, all samples are loaded.
227
+        pre_shuffle (bool, optional): Whether to shuffle the samples during init.
228
+ return_paths (bool, optional): Whether to return the paths of the samples.
229
+ cache (bool, optional): Whether to cache the samples in memory. If True, the samples are loaded only once and then cached in memory.
230
+
231
+ Attributes:
232
+ classes (list): List of the class names sorted alphabetically.
233
+ class_to_idx (dict): Dict with items (class_name, class_index).
234
+ samples (list): List of (sample path, class_index) tuples
235
+ targets (list): The class_index value for each image in the dataset
236
+ """
237
+
238
+ def __init__(
239
+ self,
240
+ root: str,
241
+ modalities: List[str],
242
+ modality_paths: Dict[str, str],
243
+ modality_transforms: Dict[str, AbstractTransform],
244
+ transform: Optional[Callable] = None,
245
+ target_transform: Optional[Callable] = None,
246
+ is_valid_file: Optional[Callable[[str], bool]] = None,
247
+ max_samples: Optional[int] = None,
248
+ pre_shuffle: bool = False,
249
+ cache: bool = False,
250
+ return_path: bool = False,
251
+ ) -> None:
252
+ super(MultiModalDatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform)
253
+ self.modalities = modalities
254
+        # Modalities without an explicit entry in modality_paths default to the modality name
255
+ self.modality_paths = modality_paths
256
+ for mod in self.modalities:
257
+ if mod not in self.modality_paths:
258
+ modality_paths[mod] = mod
259
+ self.modality_transforms = modality_transforms
260
+ self.return_path = return_path
261
+
262
+ classes, class_to_idx = self._find_classes(os.path.join(self.root, list(self.modality_paths.values())[0]))
263
+ extensions = UNIFIED_EXTENSIONS if is_valid_file is None else None
264
+
265
+ samples = {
266
+ mod: make_dataset(
267
+ os.path.join(self.root, f'{self.modality_paths[mod]}'),
268
+ class_to_idx,
269
+ extensions,
270
+ is_valid_file,
271
+ cache_path=os.path.join(self.root, 'dataloader_cache', f'{self.modality_paths[mod]}.pkl') if cache else None)
272
+ for mod in self.modalities
273
+ }
274
+
275
+ for mod, mod_samples in samples.items():
276
+ if len(mod_samples) == 0:
277
+                msg = "Found 0 files in subfolders of: {}\n".format(os.path.join(self.root, f'{self.modality_paths[mod]}'))
278
+ if extensions is not None:
279
+ msg += "Supported extensions are: {}".format(",".join(extensions))
280
+ raise RuntimeError(msg)
281
+
282
+ self.extensions = extensions
283
+
284
+ self.classes = classes
285
+ self.class_to_idx = class_to_idx
286
+ self.samples = samples
287
+
288
+ # Select random subset of dataset if so specified
289
+ if isinstance(max_samples, int):
290
+ total_samples = len(list(self.samples.values())[0])
291
+ np.random.seed(0)
292
+ permutation = np.random.permutation(total_samples)
293
+ for task in samples:
294
+ self.samples[task] = [self.samples[task][i] for i in permutation][:max_samples]
295
+
296
+ if pre_shuffle:
297
+ total_samples = len(list(self.samples.values())[0])
298
+ np.random.seed(100)
299
+ permutation = np.random.permutation(total_samples)
300
+ for task in samples:
301
+ self.samples[task] = [self.samples[task][i] for i in permutation]
302
+
303
+ self.cache = {}
304
+ self.imgs = self.samples
305
+
306
+ def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
307
+ """
308
+ Finds the class folders in a dataset.
309
+
310
+ Args:
311
+ dir (string): Root directory path.
312
+
313
+ Returns:
314
+ tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
315
+
316
+ Ensures:
317
+ No class is a subdirectory of another.
318
+ """
319
+ classes = [d.name for d in os.scandir(dir) if d.is_dir()]
320
+ classes.sort()
321
+ class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
322
+ return classes, class_to_idx
323
+
324
+ def get_class_and_file(self, path: str) -> Tuple[str, str]:
325
+ """ Extracts the class and file name from a path. """
326
+ class_id, file_name = path.split('/')[-2:]
327
+ file_name = file_name.split('.')[0]
328
+ return class_id, file_name
329
+
330
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
331
+ """
332
+ Args:
333
+ index (int): Index
334
+
335
+ Returns:
336
+ tuple: (sample, target) where target is class_index of the target class.
337
+ """
338
+ if index in self.cache:
339
+ sample_dict, target = deepcopy(self.cache[index])
340
+ else:
341
+ sample_dict = {}
342
+ for mod in self.modalities:
343
+ path, target = self.samples[mod][index]
344
+ sample = self.modality_transforms[get_transform_key(mod)].load(path)
345
+ sample_dict[mod] = sample
346
+ # self.cache[index] = deepcopy((sample_dict, target))
347
+
348
+ if self.transform is not None:
349
+ sample_dict = self.transform(sample_dict)
350
+ if self.target_transform is not None:
351
+ target = self.target_transform(target)
352
+
353
+ sample_dict['class_idx'] = target
354
+
355
+ if self.return_path and not index in self.cache:
356
+ class_id, file_name = self.get_class_and_file(path)
357
+ sample_dict['class_id'] = class_id
358
+ sample_dict['file_name'] = file_name
359
+
360
+ return sample_dict
361
+
362
+ def __len__(self) -> int:
363
+ return len(list(self.samples.values())[0])
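
For orientation, a minimal sketch of instantiating MultiModalDatasetFolder (illustrative, not part of this commit; the dataset root and modality names are made up, and it is assumed that MODALITY_TRANSFORMS from fourm.data.modality_info provides transforms for these modalities):

from fourm.data.modality_info import MODALITY_TRANSFORMS
from fourm.data.multimodal_dataset_folder import MultiModalDatasetFolder

# Expects a layout such as root/rgb/class_x/0001.jpg and root/caption/class_x/0001.json,
# with file names aligned across the modality folders.
dataset = MultiModalDatasetFolder(
    root='/path/to/dataset/val',
    modalities=['rgb', 'caption'],
    modality_paths={'rgb': 'rgb', 'caption': 'caption'},
    modality_transforms=MODALITY_TRANSFORMS,
    transform=None,  # in the full pipeline this is a UnifiedDataTransform followed by UnifiedMasking
)
sample = dataset[0]  # dict with the raw 'rgb' and 'caption' samples plus 'class_idx'
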
fourm/data/pretrain_utils.py ADDED
@@ -0,0 +1,292 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import copy
15
+
16
+ import torch
17
+ import yaml
18
+
19
+ import fourm.utils as utils
20
+
21
+ from fourm.data import (CenterCropImageAugmenter, EmptyAugmenter,
22
+ PreTokenizedImageAugmenter,RandomCropImageAugmenter, build_fm_pretraining_dataset,
23
+ build_huggingface_pretraining_dataloader,
24
+ build_wds_fm_pretraining_dataloader)
25
+ from fourm.data.modality_transforms import CaptionTransform
26
+ from fourm.data.modality_info import MODALITY_TRANSFORMS
27
+
28
+
29
+ def setup_sampling_mod_info(dataset_config, modality_info):
30
+ # Subset of modality info for each dataset
31
+
32
+ # Input and output modalities for one dataset
33
+ in_domains = sorted(dataset_config['in_domains'].split('-'))
34
+ out_domains = sorted(dataset_config['out_domains'].split('-'))
35
+ all_domains = sorted(list(set(in_domains) | set(out_domains)))
36
+
37
+ mod_info = copy.deepcopy(modality_info)
38
+ mod_info = {mod: mod_info[mod] for mod in all_domains}
39
+
40
+ # Dirichlet concentration parameter (Alpha)
41
+ if dataset_config.get('alphas_config', None) is None:
42
+ for mod in mod_info:
43
+ mod_info[mod]["input_alphas"] = [0.]
44
+ mod_info[mod]["target_alphas"] = [0.]
45
+
46
+ if 'input_alphas' in dataset_config:
47
+ input_alphas = dataset_config['input_alphas'].split('-')
48
+ if len(input_alphas) == 1:
49
+ input_alphas = [float(input_alphas[0])] * len(in_domains)
50
+ else:
51
+ input_alphas = [float(alpha) for alpha in input_alphas]
52
+ for mod, alpha in zip(in_domains, input_alphas):
53
+ mod_info[mod]['input_alphas'] = [alpha]
54
+
55
+ if 'target_alphas' in dataset_config:
56
+ target_alphas = dataset_config['target_alphas'].split('-')
57
+ if len(target_alphas) == 1:
58
+ target_alphas = [float(target_alphas[0])] * len(out_domains)
59
+ else:
60
+ target_alphas = [float(alpha) for alpha in target_alphas]
61
+ for mod, alpha in zip(out_domains, target_alphas):
62
+ mod_info[mod]["target_alphas"] = [alpha]
63
+
64
+ sampling_weights = None
65
+ else:
66
+ print(f"Loading alphas config from: {dataset_config['alphas_config']}")
67
+ with open(dataset_config['alphas_config'], "r") as f:
68
+ alphas_config = yaml.safe_load(f)
69
+
70
+ if 'sampling_weights' in alphas_config:
71
+ sampling_weights = alphas_config['sampling_weights']
72
+ alphas_config = alphas_config['alphas_mixture']
73
+ else:
74
+ sampling_weights = None
75
+
76
+ for mod in mod_info:
77
+ mod_info[mod]["input_alphas"] = alphas_config[mod]["input_alphas"]
78
+ mod_info[mod]["target_alphas"] = alphas_config[mod]["target_alphas"]
79
+ if modality_info[mod]['type'] in ['seq', 'seq_emb', 'seq_token']:
80
+ mod_info[mod]['keep'] = alphas_config[mod]['keep']
81
+
82
+ return mod_info, sampling_weights
83
+
84
+ def get_train_dataloader(dataset_config, modality_info, sampling_weights, text_tokenizer, input_size,
85
+ num_input_tokens, num_target_tokens, min_input_tokens, min_target_tokens,
86
+ num_tasks, num_workers, dataset_batch_size=None, epoch_size=None):
87
+
88
+ in_domains = sorted(list(dataset_config['in_domains'].split('-')))
89
+ out_domains = sorted(list(dataset_config['out_domains'].split('-')))
90
+ all_domains = sorted(list(set(in_domains) | set(out_domains)))
91
+
92
+ modality_transforms = MODALITY_TRANSFORMS
93
+ if 'caption' in modality_transforms:
94
+ modality_transforms['caption'] = CaptionTransform(
95
+ aligned_captions=dataset_config.get('aligned_captions', True)
96
+ )
97
+
98
+ if dataset_config['type'] == 'multimodal':
99
+
100
+ is_pretokenized = any([modality_info[mod].get('pretokenized', False) for mod in modality_info])
101
+ if is_pretokenized:
102
+ # Multi-modal training data augmentation (uses pre-tokenized data augmentation)
103
+ image_augmenter = PreTokenizedImageAugmenter(
104
+ target_size=input_size,
105
+ no_aug=(not dataset_config.get('tok_train_aug', True)),
106
+ main_domain=dataset_config['main_augment_domain']
107
+ )
108
+ else:
109
+ image_augmenter = RandomCropImageAugmenter(
110
+ target_size=input_size,
111
+ hflip=dataset_config.get('hflip'),
112
+ crop_scale=tuple(dataset_config.get('crop_scale')),
113
+ crop_ratio=tuple(dataset_config.get('crop_ratio')),
114
+ )
115
+
116
+ # Input and target token ranges
117
+ num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
118
+ num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
119
+ min_input_tokens = dataset_config.get('min_input_tokens', min_input_tokens)
120
+ min_target_tokens = dataset_config.get('min_target_tokens', min_target_tokens)
121
+ min_input_tokens = num_input_tokens if min_input_tokens is None else min_input_tokens
122
+ min_target_tokens = num_target_tokens if min_target_tokens is None else min_target_tokens
123
+
124
+
125
+ if dataset_config['use_wds']:
126
+ # Using webdataset
127
+ loader = build_wds_fm_pretraining_dataloader(
128
+ data_path=dataset_config['data_path'], all_domains=all_domains,
129
+ modality_info=modality_info, modality_transforms=modality_transforms,
130
+ image_augmenter=image_augmenter, text_tokenizer=text_tokenizer,
131
+ input_tokens_range=(min_input_tokens, num_input_tokens),
132
+ target_tokens_range=(min_target_tokens, num_target_tokens),
133
+ num_gpus=num_tasks, num_workers=num_workers,
134
+ batch_size=dataset_batch_size, epoch_size=epoch_size,
135
+ modality_name_map=dataset_config.get('modality_name_map', None),
136
+ shuffle_buffer_load=dataset_config.get('wds_shuffle_buffer_tar', 1_000),
137
+ shuffle_buffer_repeat=dataset_config.get('wds_shuffle_buffer_repeat', 1_000),
138
+ n_repeats=dataset_config.get('wds_n_repeats', 1),
139
+ sampling_weights=sampling_weights,
140
+ )
141
+ else:
142
+ dataset_train = build_fm_pretraining_dataset(
143
+ data_path=dataset_config['data_path'],
144
+ all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
145
+ image_augmenter=image_augmenter, text_tokenizer=text_tokenizer,
146
+ input_tokens_range=(min_input_tokens, num_input_tokens),
147
+ target_tokens_range=(min_target_tokens, num_target_tokens)
148
+ )
149
+ sampler_train = torch.utils.data.DistributedSampler(
150
+ dataset_train, num_replicas=num_tasks, rank=utils.get_rank(), shuffle=True, drop_last=True,
151
+ )
152
+ # DataLoader has batch size 1 as it then gets collated through the Mixture dataloader
153
+ loader = torch.utils.data.DataLoader(
154
+ dataset_train, sampler=sampler_train,
155
+ batch_size=1, num_workers=0,
156
+ pin_memory=False, drop_last=True,
157
+ collate_fn=lambda x: x[0],
158
+ )
159
+
160
+ elif dataset_config['type'] == 'huggingface':
161
+
162
+ # Input and target token ranges
163
+ num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
164
+ num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
165
+
166
+ if dataset_config.get('use_wds', False):
167
+ raise NotImplementedError('Webdataset not yet implemented for huggingface datasets.')
168
+ else:
169
+ loader = build_huggingface_pretraining_dataloader(
170
+ data_path=dataset_config['data_path'], all_domains=all_domains,
171
+ modality_info=modality_info, modality_transforms=modality_transforms,
172
+ image_augmenter=EmptyAugmenter(), text_tokenizer=text_tokenizer,
173
+ input_tokens_range=(num_input_tokens, num_input_tokens),
174
+ target_tokens_range=(num_target_tokens, num_target_tokens),
175
+ num_gpus=num_tasks, num_workers=num_workers,
176
+ batch_size=dataset_batch_size, epoch_size=epoch_size,
177
+ split='train', streaming=True, rename_text_to_caption=True,
178
+ shuffle_buffer_load=dataset_config.get('shuffle_buffer_load', 1_000),
179
+ shuffle_seed=0,
180
+ )
181
+ else:
182
+ raise NotImplementedError(f'Dataset type {dataset_config["type"]} not implemented.')
183
+
184
+ return loader
185
+
186
+
187
+ def cfgs_get(key, val_config, dataset_name, train_configs, default=None):
188
+ """ Try to retrieve a key from the validation set config.
189
+ If it does not exist, default to retrieving it from the train set config
190
+ with the same dataset name.
191
+ """
192
+ return val_config.get(key, train_configs[dataset_name].get(key, default))
193
+
194
+
195
+ def get_val_dataloader(dataset_config, dataset_name, train_configs, modality_info, sampling_weights, text_tokenizer,
196
+ input_size, num_input_tokens, num_target_tokens, min_input_tokens, min_target_tokens,
197
+ fixed_eval, fixed_eval_input_tokens, fixed_eval_target_tokens,
198
+ dist_eval, num_tasks, num_workers, batch_size, pin_mem):
199
+
200
+ in_domains = sorted(list(cfgs_get('in_domains', dataset_config, dataset_name, train_configs).split('-')))
201
+ out_domains = sorted(list(cfgs_get('out_domains', dataset_config, dataset_name, train_configs).split('-')))
202
+ all_domains = sorted(list(set(in_domains) | set(out_domains)))
203
+
204
+ modality_transforms = MODALITY_TRANSFORMS
205
+ if 'caption' in modality_transforms:
206
+ modality_transforms['caption'] = CaptionTransform(
207
+ aligned_captions=cfgs_get('aligned_captions', dataset_config, dataset_name, train_configs, True)
208
+ )
209
+
210
+ dataset_type = cfgs_get('type', dataset_config, dataset_name, train_configs)
211
+
212
+ if dataset_type == 'multimodal':
213
+
214
+ main_augment_domain = cfgs_get('main_augment_domain', dataset_config, dataset_name, train_configs)
215
+ is_pretokenized = any([modality_info[mod].get('pretokenized', False) for mod in modality_info])
216
+ if is_pretokenized:
217
+ eval_image_augmenter = PreTokenizedImageAugmenter(
218
+ target_size=input_size, no_aug=True, main_domain=main_augment_domain
219
+ )
220
+ else:
221
+ eval_image_augmenter = CenterCropImageAugmenter(
222
+ target_size=input_size, main_domain=main_augment_domain
223
+ )
224
+
225
+
226
+ if fixed_eval:
227
+ input_tokens_range=(fixed_eval_input_tokens, fixed_eval_input_tokens)
228
+ target_tokens_range=(fixed_eval_target_tokens, fixed_eval_target_tokens)
229
+ else:
230
+ # Input and target token ranges
231
+ num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
232
+ num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
233
+ min_input_tokens = dataset_config.get('min_input_tokens', min_input_tokens)
234
+ min_target_tokens = dataset_config.get('min_target_tokens', min_target_tokens)
235
+ min_input_tokens = num_input_tokens if min_input_tokens is None else min_input_tokens
236
+ min_target_tokens = num_target_tokens if min_target_tokens is None else min_target_tokens
237
+ input_tokens_range = (min_input_tokens, num_input_tokens)
238
+ target_tokens_range = (min_target_tokens, num_target_tokens)
239
+
240
+ dataset_val = build_fm_pretraining_dataset(
241
+ data_path=cfgs_get('data_path', dataset_config, dataset_name, train_configs),
242
+ all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
243
+ image_augmenter=eval_image_augmenter, text_tokenizer=text_tokenizer,
244
+ input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range
245
+ )
246
+
247
+ print("Warning: Eval stats may vary slightly as the masking applied on images is random.")
248
+ if dist_eval:
249
+ if len(dataset_val) % num_tasks != 0:
250
+ print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
251
+ 'This will slightly alter validation results as extra duplicate entries are added to achieve '
252
+ 'equal num of samples per-process.')
253
+ sampler_val = torch.utils.data.DistributedSampler(
254
+ dataset_val, num_replicas=num_tasks, rank=utils.get_rank(), shuffle=False)
255
+ else:
256
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
257
+ loader = torch.utils.data.DataLoader(
258
+ dataset_val, sampler=sampler_val,
259
+ batch_size=batch_size,
260
+ num_workers=num_workers,
261
+ pin_memory=pin_mem,
262
+ drop_last=False,
263
+ )
264
+
265
+ elif dataset_type == 'huggingface':
266
+
267
+ if fixed_eval:
268
+ input_tokens_range=(fixed_eval_input_tokens, fixed_eval_input_tokens)
269
+ target_tokens_range=(fixed_eval_target_tokens, fixed_eval_target_tokens)
270
+ else:
271
+ # Input and target token ranges
272
+ num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
273
+ num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
274
+ input_tokens_range = (num_input_tokens, num_input_tokens)
275
+ target_tokens_range = (num_target_tokens, num_target_tokens)
276
+
277
+ loader = build_huggingface_pretraining_dataloader(
278
+ data_path=cfgs_get('data_path', dataset_config, dataset_name, train_configs),
279
+ all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
280
+ image_augmenter=EmptyAugmenter(), text_tokenizer=text_tokenizer,
281
+ input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range,
282
+ num_gpus=num_tasks, num_workers=num_workers,
283
+ batch_size=batch_size, epoch_size=None,
284
+ split='validation', streaming=True, rename_text_to_caption=True,
285
+ shuffle_buffer_load=cfgs_get('shuffle_buffer_load', dataset_config, dataset_name, train_configs, 1_000),
286
+ shuffle_seed=0,
287
+ )
288
+
289
+ else:
290
+ raise NotImplementedError(f'Dataset type {dataset_type} not implemented.')
291
+
292
+ return loader
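
For reference, a sketch of the dataset_config fields read by setup_sampling_mod_info and get_train_dataloader above (illustrative, not part of this commit; the paths, modality names, and alpha values are made up, and MODALITY_INFO is assumed to be the modality info dictionary exported by fourm.data.modality_info):

from fourm.data.modality_info import MODALITY_INFO
from fourm.data.pretrain_utils import setup_sampling_mod_info

dataset_config = {
    'type': 'multimodal',                          # or 'huggingface'
    'use_wds': True,                               # take the webdataset branch in get_train_dataloader
    'data_path': '/path/to/shards/{00000..00099}.tar',
    'in_domains': 'rgb@224-caption',               # '-'-separated input modalities
    'out_domains': 'caption',                      # '-'-separated target modalities
    'input_alphas': '0.5',                         # a single value is broadcast to all in_domains
    'target_alphas': '0.5',
    'main_augment_domain': 'rgb@224',
    'tok_train_aug': True,
}
mod_info, sampling_weights = setup_sampling_mod_info(dataset_config, MODALITY_INFO)
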
fourm/data/transfer_utils.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+
16
+ def convert_samples_to_mod_dict(samples, input_mod, target_mod, num_input_tokens, num_target_tokens):
17
+ """Converts a sample (e.g. a batch of RGB images) to a mod dict that can be passed directly to FourM.
18
+ Assumes both the input modality and target modality are dense tasks.
19
+ """
20
+
21
+ B = samples.shape[0]
22
+ device = samples.device
23
+
24
+ if input_mod == target_mod:
25
+ assert(num_input_tokens == num_target_tokens)
26
+ mod_dict = {
27
+ input_mod: {
28
+ 'tensor': samples,
29
+ 'input_mask': torch.zeros((B, num_input_tokens), dtype=torch.bool, device=device),
30
+ 'target_mask': torch.zeros((B, num_target_tokens), dtype=torch.bool, device=device),
31
+ 'decoder_attention_mask': torch.zeros((B, num_target_tokens), dtype=torch.int, device=device),
32
+ },
33
+ }
34
+ mod_dict[input_mod]['decoder_attention_mask'][:, 0] = num_target_tokens
35
+
36
+ else:
37
+ mod_dict = {
38
+ input_mod: {
39
+ 'tensor': samples,
40
+ 'input_mask': torch.zeros((B, num_input_tokens), dtype=torch.bool, device=samples.device),
41
+ 'target_mask': torch.ones((B, num_input_tokens), dtype=torch.bool, device=samples.device),
42
+ 'decoder_attention_mask': torch.zeros((B, num_input_tokens), dtype=torch.int, device=samples.device),
43
+ },
44
+ target_mod: {
45
+ 'tensor': torch.zeros((B, num_target_tokens), dtype=torch.long, device=samples.device),
46
+ 'input_mask': torch.ones((B, num_target_tokens), dtype=torch.bool, device=samples.device),
47
+ 'target_mask': torch.zeros((B, num_target_tokens), dtype=torch.bool, device=samples.device),
48
+ 'decoder_attention_mask': torch.ones((B, num_target_tokens), dtype=torch.int, device=samples.device),
49
+ },
50
+ }
51
+ mod_dict[target_mod]['decoder_attention_mask'][:, 0] = num_target_tokens
52
+
53
+ return mod_dict
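
A minimal usage sketch for the helper above (illustrative, not part of this commit; the modality names and token counts are made up):

import torch
from fourm.data.transfer_utils import convert_samples_to_mod_dict

samples = torch.randn(8, 3, 224, 224)  # a batch of RGB images
mod_dict = convert_samples_to_mod_dict(
    samples, input_mod='rgb@224', target_mod='tok_semseg@224',
    num_input_tokens=196, num_target_tokens=196,
)
# mod_dict['rgb@224']['tensor'] holds the images with all 196 input positions visible,
# while 'tok_semseg@224' is set up as a fully masked target for the decoder to predict.
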
fourm/data/unified_datasets.py ADDED
@@ -0,0 +1,557 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import copy
15
+ import io
16
+ import itertools
17
+ import os
18
+ import re
19
+ from functools import partial
20
+ from typing import Any, Callable, Dict, Iterable, List, Optional
21
+
22
+ import braceexpand
23
+ import numpy as np
24
+ import torch
25
+ import webdataset as wds
26
+ from PIL import Image
27
+ from torch.utils.data import IterableDataset
28
+ from torch.utils.data._utils.collate import default_collate
29
+ from torchvision import transforms
30
+ from webdataset.filters import pipelinefilter, reraise_exception
31
+ from webdataset.handlers import warn_and_continue
32
+
33
+ try:
34
+ # Optionally load huggingface datasets
35
+ from datasets import load_dataset
36
+ from datasets.distributed import split_dataset_by_node
37
+ except ImportError:
38
+ print("Huggingface datasets not installed. Please install with `pip install datasets`.")
39
+
40
+ from fourm.data.masking import TransferMasking, UnifiedMasking
41
+ from fourm.data.modality_transforms import (CropSettingsTransform, IdentityTransform,
42
+ MaskTransform, UnifiedDataTransform,
43
+ get_transform_key)
44
+ from fourm.data.multimodal_dataset_folder import MultiModalDatasetFolder
45
+ from fourm.utils.dist import get_rank, get_world_size
46
+
47
+
48
+ def build_fm_pretraining_dataset(
49
+ data_path, all_domains, modality_info, modality_transforms,
50
+ image_augmenter, text_tokenizer,
51
+ input_tokens_range, target_tokens_range,
52
+ sampling_weights=None):
53
+ """Builds the FourM pre-training dataset based on the given arguments.
54
+    This function should mainly be used for smaller datasets (e.g. validation sets),
55
+ while large training sets should be loaded with build_wds_fm_pretraining_dataloader in webdataset format.
56
+
57
+ Args:
58
+ data_path: Path to the dataset.
59
+ all_domains: List of all modalities to be used.
60
+ modality_info: Dictionary containing information about the modalities.
61
+ modality_transforms: Dictionary containing the transforms for each modality.
62
+ image_augmenter: Image augmenter.
63
+ text_tokenizer: Text tokenizer (for sequence modalities).
64
+ input_tokens_range: Range of the input token budget.
65
+ target_tokens_range: Range of the target token budget.
66
+ sampling_weights: Sampling weights for the mixture of Dirichlet distributions.
67
+
68
+ Returns:
69
+ FourM pre-training dataset as a PyTorch Dataset.
70
+ """
71
+
72
+ transform = transforms.Compose([
73
+ UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=image_augmenter),
74
+ UnifiedMasking(modality_info=modality_info, text_tokenizer=text_tokenizer,
75
+ input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range,
76
+ sampling_weights=sampling_weights),
77
+ ])
78
+
79
+ # Remove vq domains that require a tokenizer
80
+ modalities_without_vq = [mod for mod in all_domains if not modality_info[mod].get("requires_tokenizer", False)]
81
+ # If we are using a pre-tokenized modality, we default to pre-computed crop settings
82
+ if any([modality_info[domain].get("pretokenized", False) for domain in all_domains]):
83
+ modalities_without_vq.append("crop_settings")
84
+ modality_transforms = copy.deepcopy(modality_transforms)
85
+ modality_transforms["crop_settings"] = CropSettingsTransform()
86
+
87
+ modality_paths = {mod: modality_info[mod]['path'] for mod in modality_info if modality_info[mod].get('path', None) is not None}
88
+
89
+ return MultiModalDatasetFolder(root=data_path, modalities=modalities_without_vq, modality_paths=modality_paths,
90
+ modality_transforms=modality_transforms, transform=transform)
91
+
92
+
93
+ def build_fm_transfer_dataset(
94
+ data_path, modality_info, transform, modality_transforms, all_domains,
95
+ load_mask_valid: bool = False, max_samples: Optional[int] = None,
96
+ pre_shuffle: bool = False, cache: bool = False):
97
+ """Builds the FourM transfer dataset based on the given arguments.
98
+
99
+ Args:
100
+ data_path: Path to the dataset.
101
+ modality_info: Dictionary containing information about the modalities.
102
+ transform: Transform to be applied to the dataset.
103
+ modality_transforms: Dictionary containing the transforms for each modality.
104
+ all_domains: List of all modalities to be used.
105
+ load_mask_valid: Whether to load the mask_valid "modality".
106
+ max_samples: Maximum number of samples to load.
107
+ pre_shuffle: Whether to shuffle the dataset before loading.
108
+ cache: Whether to cache the dataset in memory.
109
+
110
+ Returns:
111
+ FourM transfer dataset as a PyTorch Dataset.
112
+ """
113
+
114
+ # Remove vq domains that require a tokenizer
115
+ modalities_without_vq = [mod for mod in all_domains if not modality_info[mod].get("requires_tokenizer", False)]
116
+ # If we are using a pre-tokenized modality, we default to pre-computed crop settings
117
+ if any([modality_info[domain].get("pretokenized", False) for domain in all_domains]):
118
+ modalities_without_vq.append("crop_settings")
119
+ modality_transforms = copy.deepcopy(modality_transforms)
120
+ modality_transforms["crop_settings"] = CropSettingsTransform()
121
+
122
+ if load_mask_valid:
123
+ modalities_without_vq.append("mask_valid")
124
+ modality_transforms = copy.deepcopy(modality_transforms)
125
+ modality_transforms["mask_valid"] = MaskTransform()
126
+
127
+ modality_paths = {mod: modality_info[mod]['path'] for mod in modality_info if modality_info[mod].get('path', None) is not None}
128
+
129
+ return MultiModalDatasetFolder(root=data_path, modalities=modalities_without_vq, modality_paths=modality_paths,
130
+ modality_transforms=modality_transforms, transform=transform, max_samples=max_samples,
131
+ pre_shuffle=pre_shuffle, cache=cache)
132
+
133
+
134
+ ### Webdatasets (wds) functions
135
+
136
+ def _keyless_map(data, f, handler=reraise_exception):
137
+ """Map samples without adding __key__."""
138
+ for sample in data:
139
+ try:
140
+ result = f(sample)
141
+ except Exception as exn:
142
+ if handler(exn):
143
+ continue
144
+ else:
145
+ break
146
+ if result is None:
147
+ continue
148
+ yield result
149
+
150
+ map = pipelinefilter(_keyless_map)
151
+
152
+ def check_dots(s):
153
+ if '.gz' in s:
154
+ return s.count('.') == 2
155
+ return s.count('.') == 1
156
+
157
+ def remove_ext_with_gz(s):
158
+ if s.endswith('.gz'):
159
+ s = s.replace(".gz", "")
160
+ return os.path.splitext(s)[0]
161
+
162
+ def wds_decoder(key, value):
163
+ if key == "png" or key.endswith(".png"):
164
+ img = Image.open(io.BytesIO(value))
165
+ return img
166
+ elif key == "jpg" or key.endswith(".jpg"):
167
+ img = Image.open(io.BytesIO(value))
168
+ return img
169
+ elif key == "jpeg" or key.endswith(".jpeg"):
170
+ img = Image.open(io.BytesIO(value))
171
+ return img
172
+ elif key == 'npy' or key.endswith("npy"):
173
+ content = np.load(io.BytesIO(value), allow_pickle=True)
174
+ # try:
175
+ # content = np.load(io.BytesIO(value))
176
+ # except:
177
+ # content = np.load(io.BytesIO(value), allow_pickle=True)
178
+ return content
179
+ elif key == "jpx" or key.endswith('.jpx'):
180
+ img = Image.open(io.BytesIO(value))
181
+ return img
182
+ elif 'output' in key:
183
+ return int(value)
184
+ else:
185
+ # If not an image, use the basic handlers (.txt, .json, .pickle, .npz, ...)
186
+ # See https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py
187
+ return None
188
+
189
+ def repeat_fn(src, n_repeats=5):
190
+ """
191
+ Repeat each sample n_repeats times.
192
+ E.g. A B C ... repeated 3 times becomes A A A B B B C C C ...
193
+ Depending on the downstream application, a shuffle should be added after this.
194
+ """
195
+ for sample in src:
196
+ for _ in range(n_repeats):
197
+ yield sample
198
+
199
+ def remove_extensions(sample):
200
+ """
201
+ In webdatasets, we identify the type of a given modality by adding an extension
202
+ in the form f"{modality_name}.{modality_extension}", e.g. "rgb.jpg" or "caption.json".
203
+ This function removes them and returns a dictionary of {f"{modality_name}": modality}.
204
+ """
205
+ return {remove_ext_with_gz(k): v for k, v in sample.items()}
206
+
207
+ def filter_metadata(sample, metadata=['__key__', '__url__', 'file_name', 'class_name', 'class_idx']):
208
+ """ Filters out non-modality entries specified in metadata when loading tar files with webdatasets. """
209
+ return {k: v for k, v in sample.items() if k not in metadata}
210
+
211
+ def apply_modality_transforms(sample, modality_transforms):
212
+ """ Applies a dictionary of modality-specific transforms to a dictionary of modalities. """
213
+ return {k: (modality_transforms[get_transform_key(k)](v) if k in modality_transforms else v) for k, v in sample.items() }
214
+
215
+ def tok_to_int64(sample):
216
+ """
217
+ Pre-computed tokens are saved as int16, but we need them as int64 instead.
218
+ """
219
+ return {k: (v.astype('int64') if 'tok_' in k else v) for k, v in sample.items()}
220
+
221
+ def rename_modalities(sample, modality_paths):
222
+ """
223
+ Renames modalities to their corresponding names in modality_paths.
224
+ """
225
+ return {out_path: sample[loaded_path] for out_path, loaded_path in modality_paths.items()}
226
+
227
+ def extract_modality_names(s):
228
+ # Regular expression pattern to match anything enclosed in '{' and '}', and comma separated
229
+ pattern = r'\{([^}]*)\}'
230
+ match = re.search(pattern, s)
231
+ return match.group(1).split(',') if match else []
232
+
233
+ def identity(sample):
234
+ """ Identity function that does nothing. """
235
+ return sample
236
+
237
+ def multi_tarfile_samples(src_iter: Iterable[Dict[str, Any]],
238
+ modality_name_map: Dict[str, str] = None,
239
+ handler: Callable[[Exception], bool] = warn_and_continue):
240
+ """Webdataset does not support splitting up shards by modality, so we need to do this manually.
241
+ Usually, we would need to save all modalities in the same tar file, e.g. shard_root_train/{00000..12345}.tar,
242
+ where each shard contains 1000 samples and each sample contains all modalities.
243
+ This is not flexible when adding new modalities, so we instead save each modality in a separate tar file,
244
+ e.g. shard_root_train_rgb/{00000..12345}.tar, shard_root_train_caption/{00000..12345}.tar, etc., where each shard contains
245
+ again 1000 samples, but each sample contains only one modality. All samples in all shards have to be aligned.
246
+
247
+ This function takes an iterator over shard URLs, where we use brace expansion to specify multiple tar files per modality.
248
+ E.g. shard_root_train_[rgb,caption]/00123.tar will be expanded to shard_root_train_rgb/00123.tar and shard_root_train_caption/00123.tar,
249
+ and the samples from these two tar files will be combined into a single sample.
250
+
251
+ Args:
252
+ src_iter: Iterator over shards that *already brace expanded the shard numbers*,
253
+ e.g. {'url': 'shard_root_train_[rgb,caption]/00000.tar'}, {'url': 'shard_root_train_[rgb,caption]/00001.tar'}, ...
254
+ This function will also work when no square braces for multiple modalities are used, e.g. {'url': 'shard_root_train/00000.tar'}, ...
255
+ It can be a drop-in replacement for wds.tarfile_samples.
256
+ modality_name_map: Optional dictionary specifying a mapping from modality folder names to arbitrary other names.
257
+ handler: Function that handles exceptions. If it returns True, the shard is skipped. If it returns False, the function exits.
258
+
259
+ Yields:
260
+ Dictionary of aligned samples from all modalities.
261
+ """
262
+ for src in src_iter:
263
+
264
+ # Multi tar file URLs use brace expansion with square braces
265
+ multi_tar_urls = src['url'].translate(str.maketrans('[]', '{}'))
266
+ modality_names = extract_modality_names(multi_tar_urls)
267
+ if len(modality_names) == 0:
268
+ # Case where multi-modal braceexpand is not used, e.g. shard_dir/shard00000.tar
269
+ modality_names = [None]
270
+ multi_tar_urls = [multi_tar_urls]
271
+ elif len(modality_names) == 1:
272
+ # Brace expand doesn't work with a single entry, e.g. shard_dir/[foo]/shard00000.tar
273
+ multi_tar_urls = [multi_tar_urls.replace("{", "").replace("}", "")]
274
+ else:
275
+ # Remaining cases where multiple modalities are specified, e.g. shard_dir/[foo,bar]/shard00000.tar
276
+ multi_tar_urls = list(braceexpand.braceexpand(multi_tar_urls))
277
+
278
+ # Create tar iterators for shards of all modalities
279
+ tar_iters = [wds.tarfile_samples([{'url': tar_url}]) for tar_url in multi_tar_urls]
280
+
281
+ try:
282
+ # Loop over these iterators in parallel and combine the tar files from different modalities
283
+ for multi_tar_files in zip(*tar_iters):
284
+
285
+ merged_dict = {}
286
+ merged_dict['__key__'] = multi_tar_files[0]['__key__']
287
+ merged_dict['__url__'] = src['url']
288
+
289
+ for modality_name, modality_dict in zip(modality_names, multi_tar_files):
290
+ _key = modality_dict.pop('__key__')
291
+ _url = modality_dict.pop('__url__')
292
+
293
+ if _key != merged_dict['__key__']:
294
+ raise ValueError(f"Divergence detected! Trying to merge keys {_key} of {modality_name} and {merged_dict['__key__']} of merged_dict with modalities {merged_dict.keys()}.")
295
+
296
+ tar_is_multimodal = len(modality_dict) > 1
297
+ for k, v in modality_dict.items():
298
+ if tar_is_multimodal or check_dots(k) or modality_name is None:
299
+ # We don't change the keys in the following cases:
300
+ # 1. The shard contains multiple modalities. Then they *have* to follow the idx.modality_id.ext convention
301
+ # 2. If any key contains a dot, this means it already has the idx.modality_id.ext format (idx. is already removed at this stage)
302
+ # 3. If the modality name is None, no modality folder was specified (see beginning of function)
303
+ merged_dict[k] = v
304
+ else:
305
+ mapped_name = modality_name if modality_name_map is None else modality_name_map.get(modality_name, modality_name)
306
+ merged_dict[f'{mapped_name}.{k}'] = v
307
+
308
+ yield merged_dict
309
+
310
+ except Exception as e:
311
+ print(e)
312
+ print(f"Exception occurred while processing {src['url']}.")
313
+ if handler(e):
314
+ print('Skipping shard...')
315
+ continue
316
+ else:
317
+ break
318
+
319
+ def build_wds_fm_pretraining_dataloader(
320
+ data_path, all_domains, modality_info, modality_transforms, image_augmenter,
321
+ text_tokenizer, input_tokens_range, target_tokens_range,
322
+ num_gpus, num_workers, batch_size, epoch_size, sampling_weights=None, modality_name_map=None,
323
+ shuffle_buffer_load=1000, shuffle_buffer_repeat=5000, n_repeats=5):
324
+ """Builds the WebDataset FourM pre-training dataloader based on the given arguments.
325
+
326
+ Args:
327
+ data_path: Path to the dataset.
328
+ all_domains: List of all modalities to be used.
329
+ modality_info: Dictionary containing information about the modalities.
330
+ modality_transforms: Dictionary containing the transforms for each modality.
331
+ image_augmenter: Image augmenter.
332
+ text_tokenizer: Text tokenizer (for sequence modalities).
333
+ input_tokens_range: Range of the input token budget.
334
+ target_tokens_range: Range of the target token budget.
335
+ num_gpus: Number of GPUs.
336
+ num_workers: Number of workers.
337
+ batch_size: Batch size.
338
+ epoch_size: Number of samples per "epoch". (Here, an epoch refers to an uninterrupted stretch of training between evaluations or checkpoints, not a full pass over the dataset.)
339
+ sampling_weights: Sampling weights for the mixture of Dirichlet distributions.
340
+ modality_name_map: Optional dictionary specifying a mapping from modality folder names to arbitrary other names.
341
+ shuffle_buffer_load: Shuffle buffer size when loading samples from tar files (first shuffle).
342
+ shuffle_buffer_repeat: Shuffle buffer size after repeating samples (second shuffle).
343
+ n_repeats: Number of times to repeat each sample.
344
+
345
+ Returns:
346
+ FourM pre-training dataloader as a WebDataset DataLoader.
347
+ """
348
+
349
+ modality_paths = {mod: modality_info[mod].get('path', None) or mod for mod in modality_info}
350
+
351
+ # Remove vq domains that require a tokenizer
352
+ modalities_without_vq = [mod for mod in all_domains if not modality_info[mod].get("requires_tokenizer", False)]
353
+ # If we are using a pre-tokenized modality, we default to pre-computed crop settings
354
+ if any([modality_info[domain].get("pretokenized", False) for domain in all_domains]):
355
+ modalities_without_vq.append("crop_settings")
356
+ modality_transforms = copy.deepcopy(modality_transforms)
357
+ modality_transforms["crop_settings"] = CropSettingsTransform()
358
+ modality_paths["crop_settings"] = "crop_settings"
359
+
360
+ # Webdatasets always adds __key__ to the dictionary, so we add a transform that does nothing with it
361
+ modality_transforms["__key__"] = IdentityTransform()
362
+
363
+ transform = transforms.Compose([
364
+ UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=image_augmenter),
365
+ UnifiedMasking(modality_info=modality_info, text_tokenizer=text_tokenizer,
366
+ input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range,
367
+ sampling_weights=sampling_weights)
368
+ ])
369
+
370
+ datapipe = wds.DataPipeline(
371
+ # Infinitely sample shards from the shard list with replacement. Each worker is seeded independently.
372
+ wds.ResampledShards(data_path),
373
+ partial(multi_tarfile_samples, modality_name_map=modality_name_map), # Extract individual samples from single or multi-modal tar files
374
+ wds.shuffle(shuffle_buffer_load), # Shuffle with a buffer of given size
375
+ wds.decode(wds_decoder), # Decode from bytes to PIL images, numpy arrays, etc.
376
+ wds.filters.compose(partial(repeat_fn, n_repeats=n_repeats)), # Repeats each sample n times -> A A A B B B C C C ...
377
+ wds.shuffle(shuffle_buffer_repeat), # Shuffle again with a buffer of given size
378
+ wds.map(remove_extensions), # Remove "file extensions" from dictionary keys
379
+ wds.map(filter_metadata), # Remove non-task keys
380
+ wds.map(tok_to_int64), # Convert pre-computed tokens to int64
381
+ wds.map(partial(rename_modalities, modality_paths=modality_paths)), # Rename modalities to their corresponding names in modality_paths
382
+ wds.map(transform), # Apply data augmentation and masking
383
+ wds.batched(batch_size, collation_fn=default_collate, partial=False)
384
+ if batch_size is not None else wds.map(identity), # Batching
385
+ )
386
+
387
+ if epoch_size is not None:
388
+ batch_size_iter = batch_size if batch_size is not None else 1
389
+ datapipe = datapipe.with_epoch(epoch_size // (num_gpus * num_workers * batch_size_iter)) # Pre-define iterator length
390
+
391
+ if batch_size is not None:
392
+ # Perform multi-threaded dataloading
393
+ return wds.WebLoader(datapipe, num_workers=num_workers, batch_size=None)
394
+ else:
395
+ return datapipe
396
+
397
+
398
+ def build_wds_divae_dataloader(
399
+ data_path, modality_info, modality_transforms, image_augmenter,
400
+ num_gpus, num_workers, batch_size, epoch_size, shuffle_buffer_load=1000,
401
+ shuffle_buffer_repeat=5000, n_repeats=1):
402
+
403
+ modality_paths = {mod: modality_info[mod].get('path', None) or mod for mod in modality_info}
404
+
405
+ # Webdatasets always adds __key__ to the dictionary, so we add a transform that does nothing with it
406
+ modality_transforms["__key__"] = IdentityTransform()
407
+
408
+ transform = UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=image_augmenter)
409
+
410
+ datapipe = wds.DataPipeline(
411
+ # Infinitely sample shards from the shard list with replacement. Each worker is seeded independently.
412
+ wds.ResampledShards(data_path),
413
+ multi_tarfile_samples, # Extract individual samples from single or multi-modal tar files
414
+ wds.shuffle(shuffle_buffer_load), # Shuffle with a buffer of given size
415
+ wds.decode(wds_decoder), # Decode from bytes to PIL images, numpy arrays, etc.
416
+ wds.filters.compose(partial(repeat_fn, n_repeats=n_repeats)), # Repeats each sample n times -> A A A B B B C C C ...
417
+ wds.shuffle(shuffle_buffer_repeat), # Shuffle again with a buffer of given size
418
+ wds.map(remove_extensions), # Remove "file extensions" from dictionary keys
419
+ wds.map(filter_metadata), # Remove non-task keys
420
+ wds.map(tok_to_int64), # Convert pre-computed tokens to int64
421
+ wds.map(partial(rename_modalities, modality_paths=modality_paths)), # Rename modalities to their corresponding names in modality_paths
422
+ wds.map(transform), # Apply data augmentation and masking
423
+ wds.batched(batch_size, collation_fn=default_collate, partial=False)
424
+ if batch_size is not None else wds.map(identity), # Batching
425
+ )
426
+
427
+ if epoch_size is not None:
428
+ batch_size_iter = batch_size if batch_size is not None else 1
429
+ datapipe = datapipe.with_epoch(epoch_size // (num_gpus * num_workers * batch_size_iter)) # Pre-define iterator length
430
+
431
+ if batch_size is not None:
432
+ # Perform multi-threaded dataloading
433
+ return wds.WebLoader(datapipe, num_workers=num_workers, batch_size=None)
434
+ else:
435
+ return datapipe
436
+
437
+
438
+ ### Huggingface datasets functions
439
+
440
+ def text_to_caption(sample):
441
+ """ Rename "text" to "caption". """
442
+ return {'caption': sample['text']}
443
+
444
+
445
+ def build_huggingface_pretraining_dataloader(
446
+ data_path, all_domains, modality_info, modality_transforms, image_augmenter,
447
+ text_tokenizer, input_tokens_range, target_tokens_range,
448
+ num_gpus, num_workers, batch_size, epoch_size, split,
449
+ streaming=True, rename_text_to_caption=True, shuffle_buffer_load=10_000, shuffle_seed=0):
450
+
451
+ # Load huggingface dataset and split samples across workers. Shuffle samples in each worker
452
+ dataset = load_dataset(data_path, split=split, streaming=streaming)
453
+ dataset = split_dataset_by_node(dataset, rank=get_rank(), world_size=get_world_size())
454
+ dataset = dataset.shuffle(seed=shuffle_seed, buffer_size=shuffle_buffer_load)
455
+
456
+ modality_info = {mod: modality_info[mod] for mod in modality_info if mod in all_domains}
457
+
458
+ transform = transforms.Compose([
459
+ UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=image_augmenter),
460
+ UnifiedMasking(modality_info=modality_info, text_tokenizer=text_tokenizer,
461
+ input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range)
462
+ ])
463
+
464
+ datapipe = wds.DataPipeline(
465
+ dataset,
466
+ wds.map(text_to_caption) if rename_text_to_caption else wds.map(identity), # Rename "text" to "caption"
467
+ wds.map(filter_metadata), # Remove non-task keys
468
+ wds.map(transform), # Apply data augmentation and masking
469
+ wds.batched(batch_size, collation_fn=default_collate, partial=False)
470
+ if batch_size is not None else wds.map(identity), # Batching
471
+ )
472
+
473
+ datapipe.n_shards = dataset.n_shards
474
+ num_workers = min(num_workers, dataset.n_shards)
475
+
476
+ if epoch_size is not None:
477
+ batch_size_iter = batch_size if batch_size is not None else 1
478
+ datapipe = datapipe.with_epoch(epoch_size // (num_gpus * num_workers * batch_size_iter)) # Pre-define iterator length
479
+
480
+ if batch_size is not None:
481
+ # Perform multi-threaded dataloading
482
+ return wds.WebLoader(datapipe, num_workers=num_workers, batch_size=None)
483
+ else:
484
+ return datapipe
485
+
486
+
487
+ ### Multi-dataset loading utils
488
+ def make_empty_mod_dict(modality_info):
489
+ empty_mod_dicts = {}
490
+
491
+ for mod_name, mod_info in modality_info.items():
492
+ empty_mod = {}
493
+
494
+ # Tensor
495
+ if 'num_channels' in mod_info and 'input_size' in mod_info:
496
+ # Handle image-like modalities
497
+ max_tokens = mod_info['max_tokens']
498
+ empty_mod['tensor'] = torch.zeros((mod_info['num_channels'], mod_info['input_size'], mod_info['input_size']), dtype=torch.float32)
499
+ elif mod_name == 't5_caption':
500
+ # Handle T5 embedding
501
+ max_tokens = mod_info['max_tokens']
502
+ orig_emb_dim = mod_info['encoder_embedding']().orig_emb_dim
503
+ empty_mod['tensor'] = torch.zeros((max_tokens, orig_emb_dim), dtype=torch.float32)
504
+ elif mod_info['type'] in ['seq', 'seq_emb', 'seq_token']:
505
+ # Handle all other discrete sequence modalities
506
+ max_tokens = (mod_info['max_tokens'] + 1) * 2
507
+ empty_mod['tensor'] = torch.zeros((max_tokens), dtype=torch.int32)
508
+ else:
509
+ max_tokens = mod_info['max_tokens']
510
+ empty_mod['tensor'] = torch.zeros((max_tokens), dtype=torch.int32)
511
+
512
+ # Input and target masks
513
+ empty_mod['input_mask'] = torch.ones((max_tokens), dtype=torch.bool)
514
+ empty_mod['target_mask'] = torch.ones((max_tokens), dtype=torch.bool)
515
+
516
+ # Decoder attention mask
517
+ empty_mod['decoder_attention_mask'] = torch.zeros((max_tokens), dtype=torch.int32)
518
+
519
+ empty_mod_dicts[mod_name] = empty_mod
520
+
521
+ return empty_mod_dicts
522
+
523
+
524
+ class MixtureDataset(IterableDataset):
525
+ def __init__(self, data_iters, weights, modality_info):
526
+ self.orig_data_iters = data_iters
527
+ self.data_iters = [iter(data_iter) for data_iter in data_iters] # Create initial iterators
528
+ self.sampling_probs = np.array(weights) / sum(weights)
529
+ self.modality_info = modality_info
530
+
531
+ def reset_iterator(self, idx):
532
+ """ Reset the iterator when exhausted. """
533
+ self.data_iters[idx] = iter(self.orig_data_iters[idx])
534
+
535
+ def __iter__(self):
536
+ while True:
537
+ dataset_idx = np.random.choice(len(self.sampling_probs), p=self.sampling_probs)
538
+ try:
539
+ data = next(self.data_iters[dataset_idx])
540
+ except StopIteration: # If the iterator is exhausted
541
+ self.reset_iterator(dataset_idx) # Reset it
542
+ data = next(self.data_iters[dataset_idx])
543
+
544
+ mod_dict = make_empty_mod_dict(self.modality_info)
545
+ mod_dict.update(data)
546
+ yield mod_dict
547
+
548
+
549
+ def build_mixture_dataloader(data_iters, weights, modality_info, batch_size, num_workers, epoch_size, num_gpus):
550
+ mixture_pipe = wds.DataPipeline(
551
+ MixtureDataset(data_iters, weights, modality_info),
552
+ wds.batched(batch_size, collation_fn=default_collate, partial=False),
553
+ ).with_epoch(epoch_size // (num_gpus * num_workers * batch_size)) # Pre-define iterator length
554
+
555
+ mixture_loader = wds.WebLoader(mixture_pipe, num_workers=num_workers, batch_size=None)
556
+
557
+ return mixture_loader
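For reference, a quick sketch of how the square-brace shard URLs consumed by multi_tarfile_samples expand into per-modality tar paths, and how member keys end up namespaced in the merged sample dict. The shard paths and the modality_name_map below are made up for illustration; only the braceexpand package is assumed.

import braceexpand

# A hypothetical shard URL covering two modality folders.
url = 'shard_root_train_[rgb,caption]/00000.tar'

# Square braces are translated to curly braces so braceexpand can handle them,
# yielding one tar path per modality.
expanded = list(braceexpand.braceexpand(url.translate(str.maketrans('[]', '{}'))))
print(expanded)
# ['shard_root_train_rgb/00000.tar', 'shard_root_train_caption/00000.tar']

# Members of single-modality tars are renamed to '<modality>.<ext>' in the merged
# sample, with the folder name optionally remapped via modality_name_map.
modality_name_map = {'rgb': 'rgb@224'}  # hypothetical remapping
for folder, ext in [('rgb', 'jpg'), ('caption', 'json')]:
    mapped = modality_name_map.get(folder, folder)
    print(f'{mapped}.{ext}')
# rgb@224.jpg
# caption.json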
fourm/demo_4M_sampler.py ADDED
@@ -0,0 +1,540 @@
1
+ from typing import Optional, List
2
+ import os
3
+ import math
4
+ from PIL import Image
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import requests
9
+ from tokenizers import Tokenizer
10
+ import matplotlib.pyplot as plt
11
+ from torchvision.transforms.functional import center_crop
12
+
13
+ from fourm.models.fm import FM
14
+ from fourm.vq.vqvae import VQVAE, DiVAE
15
+ from fourm.models.generate import GenerationSampler, build_chained_generation_schedules, init_empty_target_modality, init_full_input_modality, custom_text
16
+ from fourm.utils.plotting_utils import decode_dict
17
+ from fourm.data.modality_info import MODALITY_INFO
18
+ from fourm.data.modality_transforms import RGBTransform
19
+ from fourm.utils import load_safetensors
20
+ from fourm.utils.plotting_utils import decode_dict, visualize_bboxes, plot_text_in_square, text_to_pil_image
21
+
22
+ # The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later.
23
+ torch.backends.cuda.matmul.allow_tf32 = True
24
+ # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
25
+ torch.backends.cudnn.allow_tf32 = True
26
+
27
+
28
+ # Default chained generation order
29
+ DEFAULT_ORDER = [
30
+ 'tok_clip@224', 'tok_dinov2@224', 'tok_imagebind@224', 'tok_depth@224', 'tok_normal@224',
31
+ 'tok_semseg@224', 'tok_canny_edge@224', 'tok_sam_edge@224', 'tok_rgb@224',
32
+ 'caption', 'det', 'human_poses', 'sam_instance', 'color_palette', 'metadata',
33
+ ]
34
+
35
+ # Default super-resolution chained generation order
36
+ DEFAULT_ORDER_SR = [
37
+ 'tok_clip@448', 'tok_depth@448', 'tok_normal@448',
38
+ 'tok_semseg@448', 'tok_rgb@448',
39
+ ]
40
+
41
+ # Default generation parameters for the case where the input contains RGB
42
+ DEFAULTS_RGB2X = {
43
+ 'tok_clip@224/tok_depth@224/tok_normal@224/tok_semseg@224/tok_canny_edge@224/tok_sam_edge@224': {
44
+ 'tokens_per_target': 196, 'autoregression_scheme': 'roar', 'decoding_steps': 1,
45
+ 'token_decoding_schedule': 'linear', 'temp': 0.01, 'temp_schedule': 'constant',
46
+ 'cfg_scale': 2.0, 'cfg_schedule': 'constant',
47
+ },
48
+ 'tok_dinov2@224/tok_imagebind@224': {
49
+ 'tokens_per_target': 256, 'autoregression_scheme': 'roar', 'decoding_steps': 1,
50
+ 'token_decoding_schedule': 'linear', 'temp': 0.01, 'temp_schedule': 'constant',
51
+ 'cfg_scale': 2.0, 'cfg_schedule': 'constant',
52
+ },
53
+ 'caption/det': {
54
+ 'tokens_per_target': 256, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
55
+ 'token_decoding_schedule': None, 'temp': 0.3, 'temp_schedule': 'constant',
56
+ 'cfg_scale': 1.0, 'cfg_schedule': 'constant',
57
+ },
58
+ 'human_poses': {
59
+ 'tokens_per_target': 275, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
60
+ 'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
61
+ 'cfg_scale': 1.0, 'cfg_schedule': 'constant',
62
+ },
63
+ 'sam_instance': {
64
+ 'tokens_per_target': 256, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
65
+ 'token_decoding_schedule': None, 'temp': 0.01, 'temp_schedule': 'constant',
66
+ 'cfg_scale': 1.0, 'cfg_schedule': 'constant',
67
+ },
68
+ 'color_palette': {
69
+ 'tokens_per_target': 23, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
70
+ 'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
71
+ 'cfg_scale': 1.0, 'cfg_schedule': 'constant',
72
+ },
73
+ 'metadata': {
74
+ 'tokens_per_target': 40, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
75
+ 'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
76
+ 'cfg_scale': 1.0, 'cfg_schedule': 'constant',
77
+ },
78
+ }
79
+
80
+ # Default generation parameters for the case where the target is RGB
81
+ DEFAULTS_X2RGB = {
82
+ 'tok_clip@224': {
83
+ 'tokens_per_target': 196, 'autoregression_scheme': 'roar', 'decoding_steps': 50,
84
+ 'token_decoding_schedule': 'linear', 'temp': 5.0, 'temp_schedule': 'onex:0.5:0.5',
85
+ 'cfg_scale': 3.0, 'cfg_schedule': 'constant',
86
+ },
87
+ 'tok_dinov2@224/tok_imagebind@224': {
88
+ 'tokens_per_target': 256, 'autoregression_scheme': 'roar', 'decoding_steps': 8,
89
+ 'token_decoding_schedule': 'linear', 'temp': 0.01, 'temp_schedule': 'constant',
90
+ 'cfg_scale': 2.0, 'cfg_schedule': 'constant',
91
+ },
92
+ 'tok_depth@224/tok_normal@224/tok_semseg@224/tok_canny_edge@224/tok_sam_edge@224': {
93
+ 'tokens_per_target': 196, 'autoregression_scheme': 'roar', 'decoding_steps': 8,
94
+ 'token_decoding_schedule': 'linear', 'temp': 3.0, 'temp_schedule': 'onex:0.5:0.5',
95
+ 'cfg_scale': 2.0, 'cfg_schedule': 'constant',
96
+ },
97
+ 'tok_rgb@224': {
98
+ 'tokens_per_target': 196, 'autoregression_scheme': 'roar', 'decoding_steps': 25,
99
+ 'token_decoding_schedule': 'linear', 'temp': 3.0, 'temp_schedule': 'onex:0.5:0.5',
100
+ 'cfg_scale': 2.0, 'cfg_schedule': 'constant',
101
+ },
102
+ 'caption/det': {
103
+ 'tokens_per_target': 256, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
104
+ 'token_decoding_schedule': None, 'temp': 0.3, 'temp_schedule': 'constant',
105
+ 'cfg_scale': 1.0, 'cfg_schedule': 'constant',
106
+ },
107
+ 'human_poses': {
108
+ 'tokens_per_target': 275, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
109
+ 'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
110
+ 'cfg_scale': 1.0, 'cfg_schedule': 'constant',
111
+ },
112
+ 'sam_instance': {
113
+ 'tokens_per_target': 256, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
114
+ 'token_decoding_schedule': None, 'temp': 0.01, 'temp_schedule': 'constant',
115
+ 'cfg_scale': 1.0, 'cfg_schedule': 'constant',
116
+ },
117
+ 'color_palette': {
118
+ 'tokens_per_target': 23, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
119
+ 'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
120
+ 'cfg_scale': 1.0, 'cfg_schedule': 'constant',
121
+ },
122
+ 'metadata': {
123
+ 'tokens_per_target': 40, 'autoregression_scheme': 'autoregressive', 'decoding_steps': None,
124
+ 'token_decoding_schedule': None, 'temp': 0.1, 'temp_schedule': 'constant',
125
+ 'cfg_scale': 1.0, 'cfg_schedule': 'constant',
126
+ },
127
+ }
128
+
129
+ # Default generation parameters for super-resolution
130
+ DEFAULTS_SR = {
131
+ 'tok_clip@448/tok_depth@448/tok_normal@448/tok_semseg@448/tok_rgb@448': {
132
+ 'tokens_per_target': 784, 'autoregression_scheme': 'maskgit', 'decoding_steps': 8,
133
+ 'token_decoding_schedule': 'cosine', 'temp': 1.0, 'temp_schedule': 'constant',
134
+ 'cfg_scale': 2.0, 'cfg_schedule': 'constant',
135
+ },
136
+ }
137
+
138
+ # Plotting names for each modality
139
+ MODALITY_PLOTTING_NAME_MAP = {
140
+ 'caption': 'Caption',
141
+ 'det': 'Bounding boxes',
142
+ 'human_poses': 'Human poses',
143
+ 'sam_instance': 'SAM instances (single pass)',
144
+ 'color_palette': 'Color palette',
145
+ 'metadata': 'Metadata',
146
+ 'rgb@224': 'RGB (224x224)',
147
+ 'rgb@448': 'RGB (448x448)',
148
+ 'tok_rgb@224': 'RGB (tokenized, 224x224)',
149
+ 'tok_rgb@448': 'RGB (tokenized, 448x448)',
150
+ 'tok_clip@224': 'CLIP-B/16 (224x224)',
151
+ 'tok_clip@448': 'CLIP-B/16 (448x448)',
152
+ 'tok_depth@224': 'Depth (224x224)',
153
+ 'tok_depth@448': 'Depth (448x448)',
154
+ 'tok_normal@224': 'Normals (224x224)',
155
+ 'tok_normal@448': 'Normals (448x448)',
156
+ 'tok_semseg@224': 'Semantic segmentation (224x224)',
157
+ 'tok_semseg@448': 'Semantic segmentation (448x448)',
158
+ 'tok_canny_edge@224': 'Canny edges (224x224)',
159
+ 'tok_sam_edge@224': 'SAM edges (224x224)',
160
+ 'tok_dinov2@224': 'DINOv2-B/14 (224x224)',
161
+ 'tok_imagebind@224': 'ImageBind-H/14 (224x224)',
162
+ }
163
+
164
+ # Optional fixed plotting order (by default, plotting order is determined by generation order)
165
+ MODALITY_PLOTTING_ORDER = [
166
+ 'rgb@224', 'rgb@448', 'tok_rgb@224', 'tok_rgb@448',
167
+ 'tok_depth@224', 'tok_depth@448', 'tok_normal@224', 'tok_normal@448',
168
+ 'tok_semseg@224', 'tok_semseg@448', 'tok_canny_edge@224', 'tok_sam_edge@224',
169
+ 'sam_instance', 'human_poses', 'det', 'caption', 'metadata', 'color_palette',
170
+ 'tok_clip@224', 'tok_clip@448', 'tok_dinov2@224', 'tok_imagebind@224',
171
+ ]
172
+
173
+
174
+ def get_value(defaults_dict, domain, key):
175
+ """Look up a default value belonging to a given domain and key."""
176
+ for domains, defaults in defaults_dict.items():
177
+ if domain in domains:
178
+ return defaults[key]
179
+
180
+ def load_model(model_id, model_class):
181
+ """Load a model from HuggingFace hub or a given .safetensors checkpoint path."""
182
+ if model_id.endswith('.safetensors'):
183
+ ckpt, config = load_safetensors(model_id)
184
+ model = model_class(config=config)
185
+ model.load_state_dict(ckpt)
186
+ else:
187
+ model = model_class.from_pretrained(model_id)
188
+ return model
189
+
190
+ def img_from_url(url: str):
191
+ rgb_transform = RGBTransform(imagenet_default_mean_and_std=True)
192
+ img_data = requests.get(url).content
193
+ with open('demo.png', 'wb') as handler:
194
+ handler.write(img_data)
195
+ img_pil = rgb_transform.load('./demo.png')
196
+ img_pil = rgb_transform.preprocess(img_pil)
197
+ img_pil = center_crop(img_pil, (min(img_pil.size), min(img_pil.size))).resize((224,224))
198
+ img = rgb_transform.postprocess(img_pil).unsqueeze(0)
199
+ return img
200
+
201
+
202
+ class Demo4MSampler(nn.Module):
203
+ """Convenience wrapper for easy 4M loading and generation. Users can specify HuggingFace Hub
204
+ model URLs, or downloaded safetensors checkpoint paths, and the models will be automatically
205
+ loaded. The `forward` function can be used for RGB-2-all and {caption,det}-2-all generation.
206
+ This wrapper is only intended for quickly trying out 4M models. For more advanced use cases we
207
+ recommend looking at the generation notebooks in `./notebooks/`, and `./run_generation.py`.
208
+
209
+ Args:
210
+ fm: Hub or safetensors path of 4M base model
211
+ fm_sr: Hub or safetensors path of 4M super-resolution model
212
+ tok_rgb: Hub or safetensors path of RGB tokenizer
213
+ tok_depth: Hub or safetensors path of depth tokenizer
214
+ tok_normal: Hub or safetensors path of surface normal tokenizer
215
+ tok_edge: Hub or safetensors path of canny edge tokenizer (for SAM and RGB edges)
216
+ tok_semseg: Hub or safetensors path of COCO semantic segmentation tokenizer
217
+ tok_clip: Hub or safetensors path of CLIP-B/16 tokenizer
218
+ tok_dinov2: Hub or safetensors path of DINOv2-B/14 tokenizer
219
+ tok_imagebind: Hub or safetensors path of ImageBind-H/14 tokenizer
220
+ tok_sam_instance: Hub or safetensors path of SAM instance tokenizer
221
+ tok_human_poses: Hub or safetensors path of human poses tokenizer
222
+ tok_text: Path to text tokenizer JSON file
223
+ mods: Optional list of modalities to override default behavior of generating everything
224
+ mods_sr: Optional list of super-res modalities to override default behavior of generating everything
225
+ """
226
+ def __init__(self,
227
+ fm: str = 'EPFL-VILAB/4M-21_XL_CC12M',
228
+ fm_sr: Optional[str] = 'EPFL-VILAB/4M-7-SR_L_CC12M',
229
+ tok_rgb: Optional[str] = 'EPFL-VILAB/4M_tokenizers_rgb_16k_224-448',
230
+ tok_depth: Optional[str] = 'EPFL-VILAB/4M_tokenizers_depth_8k_224-448',
231
+ tok_normal: Optional[str] = 'EPFL-VILAB/4M_tokenizers_normal_8k_224-448',
232
+ tok_edge: Optional[str] = 'EPFL-VILAB/4M_tokenizers_edge_8k_224-512',
233
+ tok_semseg: Optional[str] = 'EPFL-VILAB/4M_tokenizers_semseg_4k_224-448',
234
+ tok_clip: Optional[str] = 'EPFL-VILAB/4M_tokenizers_CLIP-B16_8k_224-448',
235
+ tok_dinov2: Optional[str] = 'EPFL-VILAB/4M_tokenizers_DINOv2-B14_8k_224-448',
236
+ tok_imagebind: Optional[str] = 'EPFL-VILAB/4M_tokenizers_ImageBind-H14_8k_224-448',
237
+ tok_sam_instance: Optional[str] = 'EPFL-VILAB/4M_tokenizers_sam-instance_1k_64',
238
+ tok_human_poses: Optional[str] = 'EPFL-VILAB/4M_tokenizers_human-poses_1k_8',
239
+ tok_text: str = './fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json',
240
+ mods: Optional[List[str]] = None,
241
+ mods_sr: Optional[List[str]] = None,
242
+ verbose: bool = True):
243
+ super().__init__()
244
+
245
+ self.verbose = verbose
246
+ if self.verbose:
247
+ print('Loading 4M models and tokenizers...', end='')
248
+
249
+ # Load 4M model and initialize sampler
250
+ fm = load_model(fm, FM)
251
+ self.sampler_fm = GenerationSampler(fm)
252
+ self.mods = mods or list(set(fm.encoder_modalities) | set(fm.decoder_modalities))
253
+
254
+ # Load optional 4M super-res model and initialize sampler
255
+ if fm_sr is not None:
256
+ fm_sr = load_model(fm_sr, FM)
257
+ self.sampler_fm_sr = GenerationSampler(fm_sr)
258
+ self.mods_sr = mods_sr or list(set(fm_sr.encoder_modalities) | set(fm_sr.decoder_modalities))
259
+ else:
260
+ self.sampler_fm_sr = None
261
+
262
+ # Load tokenizers
263
+ self.toks = {}
264
+ if ('tok_rgb@224' in self.mods or 'tok_rgb@448' in self.mods_sr) and tok_rgb is not None:
265
+ self.toks['tok_rgb'] = load_model(tok_rgb, DiVAE)
266
+ if ('tok_depth@224' in self.mods or 'tok_depth@448' in self.mods_sr) and tok_depth is not None:
267
+ self.toks['tok_depth'] = load_model(tok_depth, DiVAE)
268
+ if ('tok_normal@224' in self.mods or 'tok_normal@448' in self.mods_sr) and tok_normal is not None:
269
+ self.toks['tok_normal'] = load_model(tok_normal, DiVAE)
270
+ if ('tok_canny_edge@224' in self.mods or 'tok_sam_edge@224' in self.mods) and tok_edge is not None:
271
+ self.toks['tok_canny_edge'] = load_model(tok_edge, DiVAE)
272
+ self.toks['tok_sam_edge'] = self.toks['tok_canny_edge'] # Shared tokenizer
273
+ if ('tok_semseg@224' in self.mods or 'tok_semseg@448' in self.mods_sr) and tok_semseg is not None:
274
+ self.toks['tok_semseg'] = load_model(tok_semseg, VQVAE)
275
+ if ('tok_clip@224' in self.mods or 'tok_clip@448' in self.mods_sr) and tok_clip is not None:
276
+ self.toks['tok_clip'] = load_model(tok_clip, VQVAE)
277
+ if 'tok_dinov2@224' in self.mods and tok_dinov2 is not None:
278
+ self.toks['tok_dinov2'] = load_model(tok_dinov2, VQVAE)
279
+ if 'tok_imagebind@224' in self.mods and tok_imagebind is not None:
280
+ self.toks['tok_imagebind'] = load_model(tok_imagebind, VQVAE)
281
+ if 'sam_instance' in self.mods and tok_sam_instance is not None:
282
+ self.toks['sam_instance'] = load_model(tok_sam_instance, VQVAE)
283
+ if 'human_poses' in self.mods and tok_human_poses is not None:
284
+ self.toks['human_poses'] = load_model(tok_human_poses, VQVAE)
285
+ self.toks = nn.ModuleDict(self.toks)
286
+ self.tok_text = Tokenizer.from_file(tok_text)
287
+
288
+ if self.verbose:
289
+ print(' done!')
290
+
291
+ @property
292
+ def device(self):
293
+ return next(self.parameters()).device
294
+
295
+ def __setup_conds_and_targets(self, sample):
296
+ # Input and output modalities
297
+ cond_domains = [domain for domain in list(sample.keys()) if domain in self.mods]
298
+ target_domains = [domain for domain in DEFAULT_ORDER if (domain not in cond_domains and domain in self.mods)]
299
+ if 'rgb@224' in cond_domains:
300
+ # Do not generate tokenized RGB if pixel RGB is given as input
301
+ target_domains.remove('tok_rgb@224')
302
+ return cond_domains, target_domains
303
+
304
+ def __setup_sr_conds_and_targets(self, sample):
305
+ cond_domains_sr = [domain for domain in list(sample.keys()) if domain in self.mods_sr]
306
+ target_domains_sr = [domain for domain in DEFAULT_ORDER_SR if (domain.replace('448', '224') in cond_domains_sr and domain in self.mods_sr)]
307
+ return cond_domains_sr, target_domains_sr
308
+
309
+ def __setup_sample_and_schedule(self, sample, cond_domains, target_domains, cfg_grow_conditioning=True):
310
+ # 1 - Setup generation schedule
311
+
312
+ defaults = DEFAULTS_RGB2X if ('rgb@224' in cond_domains or 'tok_rgb@224' in cond_domains) else DEFAULTS_X2RGB
313
+
314
+ tokens_per_target = [get_value(defaults, domain, 'tokens_per_target') for domain in target_domains]
315
+ autoregression_schemes = [get_value(defaults, domain, 'autoregression_scheme') for domain in target_domains]
316
+ decoding_steps = [get_value(defaults, domain, 'decoding_steps') for domain in target_domains]
317
+ token_decoding_schedules = [get_value(defaults, domain, 'token_decoding_schedule') for domain in target_domains]
318
+ temps = [get_value(defaults, domain, 'temp') for domain in target_domains]
319
+ temp_schedules = [get_value(defaults, domain, 'temp_schedule') for domain in target_domains]
320
+ cfg_scales = [get_value(defaults, domain, 'cfg_scale') for domain in target_domains]
321
+ cfg_schedules = [get_value(defaults, domain, 'cfg_schedule') for domain in target_domains]
322
+
323
+ schedule = build_chained_generation_schedules(
324
+ cond_domains=cond_domains, target_domains=target_domains, tokens_per_target=tokens_per_target,
325
+ autoregression_schemes=autoregression_schemes, decoding_steps=decoding_steps,
326
+ token_decoding_schedules=token_decoding_schedules, temps=temps, temp_schedules=temp_schedules,
327
+ cfg_scales=cfg_scales, cfg_schedules=cfg_schedules, cfg_grow_conditioning=cfg_grow_conditioning,
328
+ )
329
+
330
+ # 2 - Setup sample
331
+
332
+ sample_dict = {}
333
+
334
+ # Handle special cases
335
+ if 'caption' in sample:
336
+ caption = sample.pop('caption')
337
+ sample_dict = custom_text(
338
+ sample_dict, input_text=caption, eos_token='[EOS]',
339
+ key='caption', device=self.device, text_tokenizer=self.tok_text
340
+ )
341
+ if 'det' in sample:
342
+ caption = sample.pop('det')
343
+ sample_dict = custom_text(
344
+ sample_dict, input_text=caption, eos_token='[EOS]',
345
+ key='det', device=self.device, text_tokenizer=self.tok_text
346
+ )
347
+ # Add remaining modalities
348
+ sample_dict.update({domain: {'tensor': tensor} for domain, tensor in sample.items()})
349
+
350
+ # Initialize these remaining input modalities (caption and det are already initialized by custom_text)
351
+ for cond_mod in sample.keys():
352
+ sample_dict = init_full_input_modality(sample_dict, MODALITY_INFO, cond_mod, self.device, eos_id=self.tok_text.token_to_id("[EOS]"))
353
+
354
+ # Initialize target modalities
355
+ for target_mod, ntoks in zip(target_domains, tokens_per_target):
356
+ sample_dict = init_empty_target_modality(sample_dict, MODALITY_INFO, target_mod, 1, ntoks, self.device)
357
+
358
+ return sample_dict, schedule
359
+
360
+ def __setup_sr_sample_and_schedule(self, out_dict, cond_domains_sr, target_domains_sr, cfg_grow_conditioning_sr=True):
361
+ # 1 - Setup generation schedule
362
+
363
+ tokens_per_target_sr = [get_value(DEFAULTS_SR, domain, 'tokens_per_target') for domain in target_domains_sr]
364
+ autoregression_schemes_sr = [get_value(DEFAULTS_SR, domain, 'autoregression_scheme') for domain in target_domains_sr]
365
+ decoding_steps_sr = [get_value(DEFAULTS_SR, domain, 'decoding_steps') for domain in target_domains_sr]
366
+ token_decoding_schedules_sr = [get_value(DEFAULTS_SR, domain, 'token_decoding_schedule') for domain in target_domains_sr]
367
+ temps_sr = [get_value(DEFAULTS_SR, domain, 'temp') for domain in target_domains_sr]
368
+ temp_schedules_sr = [get_value(DEFAULTS_SR, domain, 'temp_schedule') for domain in target_domains_sr]
369
+ cfg_scales_sr = [get_value(DEFAULTS_SR, domain, 'cfg_scale') for domain in target_domains_sr]
370
+ cfg_schedules_sr = [get_value(DEFAULTS_SR, domain, 'cfg_schedule') for domain in target_domains_sr]
371
+
372
+ schedule_sr = build_chained_generation_schedules(
373
+ cond_domains=cond_domains_sr, target_domains=target_domains_sr, tokens_per_target=tokens_per_target_sr,
374
+ autoregression_schemes=autoregression_schemes_sr, decoding_steps=decoding_steps_sr,
375
+ token_decoding_schedules=token_decoding_schedules_sr, temps=temps_sr, temp_schedules=temp_schedules_sr,
376
+ cfg_scales=cfg_scales_sr, cfg_schedules=cfg_schedules_sr, cfg_grow_conditioning=cfg_grow_conditioning_sr,
377
+ )
378
+
379
+ # 2 - Setup sample
380
+
381
+ sample_sr = out_dict
382
+
383
+ # Handle case where generated caption or bounding boxes is just [EOS]
384
+ if 'caption' in sample_sr and sample_sr['caption']['tensor'].shape[1] <= 1 and 'caption' in cond_domains_sr:
385
+ sample_sr = custom_text(
386
+ sample_sr, input_text='[S_1]', eos_token='[EOS]',
387
+ key='caption', device=self.device, text_tokenizer=self.tok_text
388
+ )
389
+ if 'det' in sample_sr and sample_sr['det']['tensor'].shape[1] <= 1 and 'det' in cond_domains_sr:
390
+ sample_sr = custom_text(
391
+ sample_sr, input_text='[S_1]', eos_token='[EOS]',
392
+ key='det', device=self.device, text_tokenizer=self.tok_text
393
+ )
394
+
395
+ # Initialize input modalities
396
+ for cond_mod in cond_domains_sr:
397
+ sample_sr = init_full_input_modality(sample_sr, MODALITY_INFO, cond_mod, self.device, eos_id=self.tok_text.token_to_id("[EOS]"))
398
+
399
+ # Initialize target modalities
400
+ for target_mod, ntoks in zip(target_domains_sr, tokens_per_target_sr):
401
+ sample_sr = init_empty_target_modality(sample_sr, MODALITY_INFO, target_mod, 1, ntoks, self.device)
402
+
403
+ return sample_sr, schedule_sr
404
+
405
+ def forward(self, sample, seed: Optional[int] = None, top_p: float = 0.8, top_k: float = 0.0, target_modalities: Optional[List[str]] = None, perform_sr: bool = True):
406
+ seed = np.random.randint(np.iinfo(np.int64).max) if seed is None else seed  # keep an explicit seed of 0
407
+
408
+ # Prepare the generation parameters and sample
409
+ cond_domains, target_domains = self.__setup_conds_and_targets(sample)
410
+ target_domains = target_modalities or target_domains
411
+ sample, generation_schedule = self.__setup_sample_and_schedule(sample, cond_domains, target_domains)
412
+
413
+ # Generation and decoding at the base resolution 224x224
414
+ if self.verbose:
415
+ print(f'Generating {cond_domains} -> {target_domains} ...')
416
+ out_dict = self.sampler_fm.generate(
417
+ sample, generation_schedule, text_tokenizer=self.tok_text,
418
+ verbose=self.verbose, seed=seed, top_p=top_p, top_k=top_k,
419
+ )
420
+ dec_dict = decode_dict(
421
+ out_dict, self.toks, self.tok_text, image_size=224,
422
+ patch_size=16, decoding_steps=50
423
+ )
424
+
425
+ # Optional upsampling to 448x448
426
+ if self.sampler_fm_sr is not None and perform_sr:
427
+ cond_domains_sr, target_domains_sr = self.__setup_sr_conds_and_targets(out_dict)
428
+ sample_sr, generation_schedule_sr = self.__setup_sr_sample_and_schedule(out_dict, cond_domains_sr, target_domains_sr)
429
+
430
+ if self.verbose:
431
+ print(f'Super-resolving {target_domains_sr} ...')
432
+ out_dict_sr = self.sampler_fm_sr.generate(
433
+ sample_sr, generation_schedule_sr, text_tokenizer=self.tok_text,
434
+ verbose=self.verbose, seed=seed+1, top_p=top_p, top_k=top_k,
435
+ )
436
+ dec_dict = decode_dict(
437
+ out_dict_sr, self.toks, self.tok_text, image_size=448,
438
+ patch_size=16, decoding_steps=50
439
+ )
440
+
441
+ # Remove padding tokens
442
+ if 'caption' in dec_dict:
443
+ dec_dict['caption'][0].replace('[PAD]', '').strip()
444
+ if 'det' in dec_dict:
445
+ dec_dict['det'][0].replace('[PAD]', '').strip()
446
+
447
+ return dec_dict
448
+
449
+ def plot_modalities(self, mod_dict, ncols_max=5, figscale=4.0, save_path=None, use_fixed_plotting_order=False):
450
+ nmods = len(mod_dict)
451
+ ncols = min(nmods, ncols_max)
452
+ nrows = math.ceil(nmods / ncols)
453
+
454
+ fig, ax = plt.subplots(
455
+ nrows=nrows, ncols=ncols,
456
+ figsize=(ncols*figscale, nrows*figscale),
457
+ facecolor=(1, 1, 1)
458
+ )
459
+
460
+ if use_fixed_plotting_order:
461
+ mod_dict = {
462
+ k: mod_dict[k] for k in MODALITY_PLOTTING_ORDER
463
+ if k in mod_dict
464
+ }
465
+
466
+ for i, (mod_name, mod) in enumerate(mod_dict.items()):
467
+ if nrows == 1:
468
+ ax_i = ax[i]
469
+ else:
470
+ row, col = i // ncols, i % ncols
471
+ ax_i = ax[row,col]
472
+
473
+ if mod_name == 'det':
474
+ # Attempt to get the first available value from mod_dict according to the priority
475
+ keys_in_order = ['rgb@448', 'rgb@224', 'tok_rgb@448', 'tok_rgb@224']
476
+ rgb_background = next((mod_dict[key] for key in keys_in_order if key in mod_dict), np.ones((224, 224, 3)))
477
+ rgb_background = (255 * rgb_background).astype(np.uint8)
478
+ ax_i.imshow(visualize_bboxes(rgb_background, mod[0],).astype(np.uint8))
479
+ elif mod_name == 'caption':
480
+ plot_text_in_square(ax_i, mod[0], wrap_width=16, fontsize=14)
481
+ elif mod_name == 'metadata':
482
+ metadata_pred = ',\n'.join([f'{k}: {v:.2f}' if isinstance(v, float) else f'{k}: {v}' for k, v in mod.items()])
483
+ plot_text_in_square(ax_i, metadata_pred, wrap_width=36, fontsize=13)
484
+ else:
485
+ ax_i.imshow(mod)
486
+
487
+ ax_i.set_title(MODALITY_PLOTTING_NAME_MAP.get(mod_name, mod_name), fontsize=18)
488
+
489
+ for i, axis in enumerate(ax.flatten()):
490
+ axis.set_xticks([])
491
+ axis.set_yticks([])
492
+ if i >= len(mod_dict):
493
+ axis.spines['top'].set_visible(False)
494
+ axis.spines['right'].set_visible(False)
495
+ axis.spines['bottom'].set_visible(False)
496
+ axis.spines['left'].set_visible(False)
497
+
498
+ plt.tight_layout()
499
+ if save_path is not None:
500
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
501
+ plt.savefig(save_path, bbox_inches='tight', dpi=300)
502
+ plt.close()
503
+ else:
504
+ plt.show()
505
+
506
+ def modalities_to_pil(self, mod_dict, use_fixed_plotting_order=False, resize=None):
507
+ if use_fixed_plotting_order:
508
+ mod_dict = {
509
+ k: mod_dict[k] for k in MODALITY_PLOTTING_ORDER
510
+ if k in mod_dict
511
+ }
512
+
513
+ plotted_modalities = []
514
+
515
+ for i, (mod_name, mod) in enumerate(mod_dict.items()):
516
+ if mod_name == 'det':
517
+ # Attempt to get the first available value from mod_dict according to the priority
518
+ keys_in_order = ['rgb@448', 'rgb@224', 'tok_rgb@448', 'tok_rgb@224']
519
+ rgb_background = next((mod_dict[key] for key in keys_in_order if key in mod_dict), np.ones((224, 224, 3)))
520
+ rgb_background = (255 * rgb_background).astype(np.uint8)
521
+ img_pil = Image.fromarray(visualize_bboxes(rgb_background, mod[0],).astype(np.uint8))
522
+ elif mod_name == 'caption':
523
+ img_pil = text_to_pil_image(mod[0][:512], wrap_width=40, fontsize=14)
524
+ elif mod_name == 'metadata':
525
+ metadata_pred = ',\n'.join([f'{k}: {v:.2f}' if isinstance(v, float) else f'{k}: {v}' for k, v in mod.items()])
526
+ img_pil = text_to_pil_image(metadata_pred, wrap_width=36, fontsize=13)
527
+ else:
528
+ img_pil = Image.fromarray((255*mod).astype(np.uint8))
529
+
530
+ if resize is not None:
531
+ if mod_name in ['tok_clip@224', 'tok_dinov2@224', 'tok_imagebind@224', 'tok_clip@448']:
532
+ resample_mode = Image.Resampling.NEAREST
533
+ else:
534
+ resample_mode = Image.Resampling.BILINEAR
535
+ img_pil = img_pil.resize((resize, resize), resample=resample_mode)
536
+
537
+ plot_name = MODALITY_PLOTTING_NAME_MAP.get(mod_name, mod_name)
538
+ plotted_modalities.append((img_pil, plot_name))
539
+
540
+ return plotted_modalities
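As a rough usage sketch for the wrapper above (the image URL is a placeholder, and the default Hub checkpoints listed in the constructor are assumed to be reachable):

import torch

# Load 4M-21 XL, the super-resolution model, and all tokenizers from the Hub.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sampler = Demo4MSampler().eval().to(device)

# Fetch and preprocess a 224x224 RGB image (placeholder URL).
img = img_from_url('https://example.com/some_image.jpg')

# RGB-to-all generation at 224x224, followed by optional 448x448 super-resolution.
preds = sampler({'rgb@224': img.to(device)}, seed=0)

# Plot the generated modalities, or convert them to PIL images instead.
sampler.plot_modalities(preds, save_path='preds.png')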
fourm/models/__init__.py ADDED
File without changes
fourm/models/decoder_embeddings.py ADDED
@@ -0,0 +1,268 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from functools import partial
15
+ from typing import Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ from einops import repeat
20
+
21
+ from .fm_utils import build_1d_sincos_posemb, build_2d_sincos_posemb, pair
22
+
23
+
24
+ class SequenceDecoderEmbedding(nn.Module):
25
+ """Embedding module for sequence inputs, like captions or a sequence of objects.
26
+
27
+ Args:
28
+ vocab_size: Vocabulary size
29
+ max_length: Maximum number of tokens in the sequence
30
+ dim_tokens: Dimension of output tokens. Can be set using init method.
31
+ sincos_pos_emb: Set to True (default) to use fixed 1D sin-cos positional embeddings
32
+ padding_idx: Padding index for word embedding
33
+ share_embedding: Set to True to share input and output embedding weights
34
+ """
35
+ def __init__(self,
36
+ vocab_size: int,
37
+ max_length: int,
38
+ dim_tokens: Optional[int] = None,
39
+ sincos_pos_emb: bool = True,
40
+ max_sincos_pos_emb: int = 512,
41
+ padding_idx: int = 0,
42
+ share_embedding: bool = True,
43
+ **kwargs):
44
+ super().__init__()
45
+ self.vocab_size = vocab_size
46
+ self.max_length = max_length
47
+ self.dim_tokens = dim_tokens
48
+ self.sincos_pos_emb = sincos_pos_emb
49
+ self.padding_idx = padding_idx
50
+ self.max_sincos_pos_emb = max_sincos_pos_emb
51
+ self.share_embedding = share_embedding
52
+
53
+ if self.dim_tokens is not None:
54
+ self.init(dim_tokens=dim_tokens)
55
+
56
+ def init(self, dim_tokens: int = 768, init_std=0.02):
57
+ """
58
+ Initialize parts of embedding module that are dependent on dimension of tokens.
59
+ Should be called when setting up FourM.
60
+
61
+ Args:
62
+ dim_tokens: Dimension of tokens
63
+ init_std: Standard deviation of init
64
+ """
65
+ self.dim_tokens = dim_tokens
66
+
67
+ # Task embedding identifying from which task a given token comes from
68
+ # Fixed-size positional embeddings. Can be interpolated to different input sizes
69
+
70
+ if self.sincos_pos_emb:
71
+ if self.max_length > self.max_sincos_pos_emb:
72
+ raise ValueError(f"Max length ({self.max_length}) is greater than the number of posembs ({self.max_sincos_pos_emb}")
73
+ # Get all posembs, than truncate up to max length
74
+ pos_emb = build_1d_sincos_posemb(max_len=self.max_sincos_pos_emb, embed_dim=self.dim_tokens)[:self.max_length]
75
+ self.register_buffer("pos_emb", pos_emb)
76
+ else:
77
+ self.pos_emb = nn.Parameter(torch.zeros(1, self.max_length, self.dim_tokens))
78
+ nn.init.normal_(self.pos_emb, std=init_std)
79
+
80
+ self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
81
+ nn.init.normal_(self.mod_emb, std=init_std)
82
+
83
+ # Token embedding
84
+ self.token_emb = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.dim_tokens, padding_idx=self.padding_idx)
85
+
86
+ # Output projection layer
87
+ self.to_logits = nn.Linear(self.dim_tokens, self.vocab_size, bias=False)
88
+
89
+ if self.share_embedding:
90
+ # Share input and output embedding weights
91
+ self.to_logits.weight = self.token_emb.weight
92
+
93
+
94
+ @torch.jit.ignore
95
+ def no_weight_decay(self):
96
+ return set()
97
+
98
+ def forward_embed(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
99
+ """
100
+ Forward pass through embedding module, transforming sequence of ids to sequence of embeddings.
101
+ Creates corresponding modality and positional embeddings and adds them to the dict.
102
+
103
+ Args:
104
+ d (Dict[str, torch.Tensor]): Modality dict, with at least the following keys:
105
+ - 'tensor' (torch.Tensor): Token sequence for each batch. Shape (B, L) where B is the batch size and L is the sequence length.
106
+ - 'target_mask' (torch.Tensor): Mask for valid tokens in the target sequence (set to 0 for valid tokens and 1 otherwise). Shape (B, L).
107
+
108
+ Returns:
109
+ Dict[str, torch.Tensor]: Modality dict with added keys:
110
+ - 'x' (torch.Tensor): Embedded token sequence. Shape (B, L, D) where D is the embedding dimension.
111
+ - 'emb' (torch.Tensor): Sum of positional and modality embeddings for the target sequence. Shape (B, L, D).
112
+ - 'ids' (torch.Tensor): Original token sequence from input dict. Shape (B, L).
113
+ """
114
+ ids = d['tensor']
115
+ B = ids.shape[0]
116
+ assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first'
117
+
118
+ # Map to embedding
119
+ x = self.token_emb(ids)
120
+
121
+ expanded_pos_emb = repeat(self.pos_emb, "() n d -> b n d", b=B)
122
+
123
+ # Target pos encoding
124
+ target_mask = d['target_mask']
125
+ target_pos_id = (~target_mask).int().cumsum(dim=1) - 1
126
+ target_pos_id[target_mask] = 0
127
+ # The target sequence can exceed the max length; it will be truncated in the decoder
128
+ target_pos_id[target_pos_id >= self.max_length] = 0
129
+ target_pos_emb = torch.gather(expanded_pos_emb, dim=1, index=repeat(target_pos_id, "b n -> b n d", d=expanded_pos_emb.shape[2]))
130
+ target_pos_emb[target_mask] = 0
131
+
132
+ x_emb = target_pos_emb + self.mod_emb
133
+
134
+
135
+ d['x'] = x
136
+ d['emb'] = x_emb
137
+ d['ids'] = d['tensor']
138
+
139
+ return d
140
+
141
+ def forward_logits(self, x: torch.Tensor) -> torch.Tensor:
142
+ """
143
+ Forward pass through output projection layer, transforming sequence of embeddings to logits.
144
+
145
+ Args:
146
+ x (torch.Tensor): Output tokens from the decoder. Shape (B, M, D)
147
+
148
+ Returns:
149
+ torch.Tensor: Logits for each token in the sequence. Shape (B, M, V)
150
+ """
151
+ logits = self.to_logits(x)
152
+ return logits
153
+
154
+
155
+
156
+ class ImageTokenDecoderEmbedding(nn.Module):
157
+ """Embedding module for tokenized spatial inputs.
158
+
159
+ Args:
160
+ vocab_size: Vocabulary size
161
+ patch_size: Int or tuple of the patch size over the full image size.
162
+ dim_tokens: Dimension of output tokens. Can be set using init method.
163
+ sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings
164
+ image_size: Default image size. Used to initialize size of positional embeddings.
165
+ share_embedding: Set to True to share input and output embedding weights
166
+ """
167
+ def __init__(self,
168
+ vocab_size: int,
169
+ patch_size: Union[int, Tuple[int,int]] = 16,
170
+ dim_tokens: Optional[int] = None,
171
+ sincos_pos_emb: bool = True,
172
+ image_size: Union[int, Tuple[int]] = 224,
173
+ share_embedding: bool = True,
174
+ **kwargs):
175
+ super().__init__()
176
+ self.vocab_size = vocab_size
177
+ self.patch_size = pair(patch_size)
178
+ self.dim_tokens = dim_tokens
179
+ self.sincos_pos_emb = sincos_pos_emb
180
+ self.image_size = pair(image_size)
181
+ self.num_patches = (self.image_size[0] // self.patch_size[0]) * (self.image_size[1] // self.patch_size[1])
182
+ self.share_embedding = share_embedding
183
+
184
+ if self.dim_tokens is not None:
185
+ self.init(dim_tokens=dim_tokens)
186
+
187
+ def init(self, dim_tokens: int = 768, init_std=0.02):
188
+ """
189
+ Initialize parts of module that are dependent on dimension of tokens.
190
+ Should be called when setting up FourM.
191
+
192
+ Args:
193
+ dim_tokens: Dimension of tokens
194
+ init_std: Standard deviation of init
195
+ """
196
+ self.dim_tokens = dim_tokens
197
+
198
+ # Task embedding identifying from which task a given token comes from
199
+ # Fixed-size positional embeddings. Can be interpolated to different input sizes
200
+ h_posemb = self.image_size[0] // self.patch_size[0]
201
+ w_posemb = self.image_size[1] // self.patch_size[1]
202
+ if self.sincos_pos_emb:
203
+ pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens)
204
+ self.register_buffer("pos_emb", pos_emb)
205
+ else:
206
+ self.pos_emb = nn.Parameter(torch.zeros(1, (h_posemb * w_posemb), self.dim_tokens))
207
+ nn.init.normal_(self.pos_emb, std=init_std)
208
+
209
+ self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
210
+ nn.init.normal_(self.mod_emb, std=init_std)
211
+
212
+ # Token embedding (not needed if only masked tokens are given as input, but can be useful to train Token Critic)
213
+ self.token_emb = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.dim_tokens)
214
+
215
+ # Output projection layer
216
+ self.to_logits = nn.Linear(self.dim_tokens, self.vocab_size, bias=False)
217
+
218
+ if self.share_embedding:
219
+ # Share input and output embedding weights
220
+ self.to_logits.weight = self.token_emb.weight
221
+
222
+ @torch.jit.ignore
223
+ def no_weight_decay(self):
224
+ return set()
225
+
226
+ def forward_embed(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
227
+ """
228
+ Forward pass through the embedding module, transforming tokenized spatial inputs to embeddings.
229
+ Creates corresponding modality and positional embeddings and adds them to the dict.
230
+
231
+ Args:
232
+ d (Dict[str, torch.Tensor]): Modality dict, with at least the following key:
233
+ - 'tensor' (torch.Tensor): Modality tokens for each batch (e.g. from tokenized images). Shape (B, H, W) where B is the batch size, H and W are height and width after tokenization.
234
+
235
+
236
+ Returns:
237
+ Dict[str, torch.Tensor]: Modality dict with added keys:
238
+ - 'x' (torch.Tensor): Embedded token sequence, which is replaced by mask tokens in the 4M decoder. Shape (B, H*W, D) where D is the embedding dimension.
239
+ - 'emb' (torch.Tensor): Sum of positional and modality embeddings for the token sequence. Shape (B, H*W, D).
240
+ - 'ids' (torch.Tensor): Reshaped token sequence from input dict, flattened in the spatial dimensions. Shape (B, H*W).
241
+ """
242
+ ids = d['tensor']
243
+ B = ids.shape[0]
244
+ ids = ids.reshape(B, -1)
245
+
246
+ # Map to embedding
247
+ x = self.token_emb(ids)
248
+
249
+ # Create positional embedding + modality embedding
250
+ x_emb = repeat(self.pos_emb + self.mod_emb, '() n d -> b n d', b=B)
251
+
252
+ d['x'] = x
253
+ d['emb'] = x_emb
254
+ d['ids'] = ids
255
+ return d
256
+
257
+ def forward_logits(self, x: torch.Tensor) -> torch.Tensor:
258
+ """
259
+ Forward pass through output projection layer, transforming sequence of embeddings to logits.
260
+
261
+ Args:
262
+ x (torch.Tensor): Output tokens from the decoder. Shape (B, M, D)
263
+
264
+ Returns:
265
+ torch.Tensor: Logits for each token in the sequence. Shape (B, M, V)
266
+ """
267
+ logits = self.to_logits(x)
268
+ return logits
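To make the positional-indexing trick in SequenceDecoderEmbedding.forward_embed (and the analogous input-mask version in the encoder embeddings below) easier to follow, here is a toy standalone example of how position ids are derived from a mask in which False marks valid tokens:

import torch

# One sequence of length 6; positions 1, 2 and 4 hold valid target tokens.
target_mask = torch.tensor([[True, False, False, True, False, True]])

# Valid tokens are numbered 0, 1, 2, ... in order of appearance; masked slots get 0.
target_pos_id = (~target_mask).int().cumsum(dim=1) - 1
target_pos_id[target_mask] = 0
print(target_pos_id)  # tensor([[0, 0, 1, 0, 2, 0]])

# These ids index into the sin-cos positional embedding table via torch.gather,
# and the gathered embeddings at masked slots are zeroed out afterwards.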
fourm/models/encoder_embeddings.py ADDED
@@ -0,0 +1,422 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, List, Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from einops import rearrange, repeat
19
+
20
+ from .fm_utils import build_1d_sincos_posemb, build_2d_sincos_posemb, pair
21
+
22
+ class SequenceEncoderEmbedding(nn.Module):
23
+ """Embedding module for encoding sequence inputs, like captions or a sequence of objects.
24
+
25
+ Args:
26
+ vocab_size: Vocabulary size
27
+ max_length: Maximum number of tokens in the sequence
28
+ dim_tokens: Dimension of output tokens. Can be set using init method.
29
+ sincos_pos_emb: Set to True (default) to use fixed 1D sin-cos positional embeddings
30
+ max_sincos_pos_emb: Maximum allowed length for sin-cos positional embeddings
31
+ padding_idx: Padding index for word embedding
32
+ """
33
+
34
+ def __init__(self,
35
+ vocab_size: int,
36
+ max_length: int,
37
+ dim_tokens: Optional[int] = None,
38
+ sincos_pos_emb: bool = True,
39
+ max_sincos_pos_emb: int = 512,
40
+ padding_idx: int = 0,
41
+ ):
42
+ super().__init__()
43
+ self.vocab_size = vocab_size
44
+ self.max_length = max_length
45
+ self.dim_tokens = dim_tokens
46
+ self.sincos_pos_emb = sincos_pos_emb
47
+ self.padding_idx = padding_idx
48
+ self.max_sincos_pos_emb = max_sincos_pos_emb
49
+
50
+ if self.dim_tokens is not None:
51
+ self.init(dim_tokens=dim_tokens)
52
+
53
+ def init(self, dim_tokens: int = 768, init_std=0.02):
54
+ """
55
+ Initialize parts of embedding module that are dependent on dimension of tokens.
56
+ Should be called when setting up FourM.
57
+
58
+ Args:
59
+ dim_tokens: Dimension of tokens
60
+ init_std: Standard deviation of init
61
+ """
62
+ self.dim_tokens = dim_tokens
63
+
64
+ # Task embedding identifying from which task a given token comes from
65
+ # Fixed-size positional embeddings. Can be interpolated to different input sizes
66
+ if self.sincos_pos_emb:
67
+ if self.max_length > self.max_sincos_pos_emb:
68
+ raise ValueError(f"Max length ({self.max_length}) is greater than the number of posembs ({self.max_sincos_pos_emb}")
69
+ pos_emb = build_1d_sincos_posemb(max_len=self.max_sincos_pos_emb, embed_dim=self.dim_tokens)[:self.max_length]
70
+ self.register_buffer("pos_emb", pos_emb) # self.pos_emb is now a buffer for FSDP
71
+ else:
72
+ self.pos_emb = nn.Parameter(torch.zeros(1, self.max_length, self.dim_tokens))
73
+ nn.init.normal_(self.pos_emb, std=init_std)
74
+
75
+ self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
76
+ nn.init.normal_(self.mod_emb, std=init_std)
77
+
78
+ # Token embedding
79
+ self.token_emb = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.dim_tokens,
80
+ padding_idx=self.padding_idx)
81
+
82
+
83
+ @torch.jit.ignore
84
+ def no_weight_decay(self):
85
+ return set()
86
+
87
+ def forward(self, d : Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
88
+ """
89
+ Forward pass through embedding module, transforming sequence of ids to sequence of embeddings.
90
+ Creates corresponding modality and positional embeddings and adds them to the dict.
91
+
92
+ Args:
93
+ d (Dict[str, torch.Tensor]): Modality dict with at least the following keys:
94
+ - 'tensor' (torch.Tensor): Input token sequence for each batch. Shape (B, L) where B is the batch size and L is the sequence length.
95
+ - 'input_mask' (torch.Tensor): Mask for valid tokens in the input sequence (set to 0 for valid tokens and 1 otherwise). Shape (B, L).
96
+
97
+ Returns:
98
+ Dict[str, torch.Tensor]: Modality dict with added keys:
99
+ - 'x' (torch.Tensor): Embedded token sequence. Shape (B, L, D) where D is the embedding dimension.
100
+ - 'emb' (torch.Tensor): Sum of positional and modality embeddings for the input sequence. Shape (B, L, D).
101
+ """
102
+ ids = d['tensor']
103
+ B = ids.shape[0]
104
+ assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first'
105
+
106
+ # Map to embedding
107
+ x = self.token_emb(ids)
108
+
109
+ expanded_pos_emb = repeat(self.pos_emb, "() n d -> b n d", b=B)
110
+ # Input pos encoding
111
+ input_mask = d['input_mask']
112
+ input_pos_id = (~input_mask).int().cumsum(dim=1) - 1
113
+ input_pos_id[input_mask] = 0
114
+ input_pos_emb = torch.gather(expanded_pos_emb, dim=1, index=repeat(input_pos_id, "b n -> b n d", d=expanded_pos_emb.shape[2]))
115
+ input_pos_emb[input_mask] = 0
116
+
117
+ x_emb = input_pos_emb + self.mod_emb
118
+
119
+ d['x'] = x
120
+ d['emb'] = x_emb
121
+ return d
122
+
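A minimal usage sketch of SequenceEncoderEmbedding, following the docstring above; the vocabulary size, sequence length, and embedding dimension are made-up illustration values, and the import path is assumed from the file list of this commit.

import torch
from fourm.models.encoder_embeddings import SequenceEncoderEmbedding

emb = SequenceEncoderEmbedding(vocab_size=1000, max_length=8, dim_tokens=64)
ids = torch.randint(0, 1000, (2, 8))              # (B, L) token ids
input_mask = torch.zeros(2, 8, dtype=torch.bool)  # False/0 = valid token
input_mask[:, 6:] = True                          # last two positions are padding
out = emb({'tensor': ids, 'input_mask': input_mask})
print(out['x'].shape, out['emb'].shape)           # both (2, 8, 64), as described in the docstring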
123
+ class ImageTokenEncoderEmbedding(nn.Module):
124
+ """Embedding module for tokenized spatial inputs.
125
+
126
+ Args:
127
+ vocab_size: Vocabulary size
128
+ patch_size: Int or tuple of the patch size over the full image size.
129
+ dim_tokens: Dimension of output tokens. Can be set using init method.
130
+ sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings
131
+ image_size: Default image size. Used to initialize size of positional embeddings.
132
+ """
133
+ def __init__(self,
134
+ vocab_size: int,
135
+ patch_size: Union[int, Tuple[int,int]] = 16,
136
+ dim_tokens: Optional[int] = None,
137
+ sincos_pos_emb: bool = True,
138
+ image_size: Union[int, Tuple[int]] = 224,
139
+ **kwargs):
140
+
141
+ super().__init__()
142
+ self.vocab_size = vocab_size
143
+ self.patch_size = pair(patch_size)
144
+ self.dim_tokens = dim_tokens
145
+ self.sincos_pos_emb = sincos_pos_emb
146
+ self.image_size = pair(image_size)
147
+ self.num_patches = (self.image_size[0] // self.patch_size[0]) * (self.image_size[1] // self.patch_size[1])
148
+
149
+ if self.dim_tokens is not None:
150
+ self.init(dim_tokens=dim_tokens)
151
+
152
+ def init(self, dim_tokens: int = 768, init_std=0.02):
153
+ """
154
+ Initialize parts of module that are dependent on dimension of tokens.
155
+ Should be called when setting up FourM.
156
+
157
+ Args:
158
+ dim_tokens: Dimension of tokens
159
+ init_std: Standard deviation of init
160
+ """
161
+ self.dim_tokens = dim_tokens
162
+
163
+ # Task embedding identifying from which task a given token comes from
164
+ # Fixed-size positional embeddings. Can be interpolated to different input sizes
165
+ h_posemb = self.image_size[0] // self.patch_size[0]
166
+ w_posemb = self.image_size[1] // self.patch_size[1]
167
+ if self.sincos_pos_emb:
168
+ pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens)
169
+ self.register_buffer("pos_emb", pos_emb) # self.pos_emb is now a buffer for FSDP
170
+ else:
171
+ self.pos_emb = nn.Parameter(torch.zeros(1, (h_posemb * w_posemb), self.dim_tokens))
172
+ nn.init.normal_(self.pos_emb, std=init_std)
173
+
174
+ self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
175
+ nn.init.normal_(self.mod_emb, std=init_std)
176
+
177
+ # Token embedding
178
+ self.token_emb = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.dim_tokens)
179
+
180
+ @torch.jit.ignore
181
+ def no_weight_decay(self):
182
+ return set()
183
+
184
+ def forward(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
185
+ """
186
+ Forward pass through embedding module, transforming image tokens to a sequence of embeddings.
187
+ Creates corresponding modality and positional embeddings and adds them to the dict.
188
+
189
+ Args:
190
+ d (Dict[str, torch.Tensor]): Modality dict with at least the following key:
191
+ - 'tensor' (torch.Tensor): Input image tokens for each batch. Shape (B, H, W) where B is the batch size, and H, W are height and width of the tokenized image.
+ - 'input_mask' (torch.Tensor): Mask for valid tokens in the input sequence (set to 0 for valid tokens and 1 otherwise). Shape (B, L).
192
+
193
+ Returns:
194
+ Dict[str, torch.Tensor]: Modality dictionary with added keys:
195
+ - 'x' (torch.Tensor): Embedded token sequence. Shape (B, H*W, D).
196
+ - 'emb' (torch.Tensor): Sum of positional and modality embeddings for the input sequence. Shape (B, H*W, D).
197
+ """
198
+ ids = d['tensor']
199
+ B = ids.shape[0]
200
+ ids = ids.reshape(B, -1)
201
+
202
+ # Map to embedding
203
+ x = self.token_emb(ids)
204
+
205
+ # Create positional embedding + modality embedding
206
+ x_emb = repeat(self.pos_emb + self.mod_emb, '() n d -> b n d', b=B)
207
+
208
+ d['x'] = x
209
+ d['emb'] = x_emb
210
+
211
+ return d
212
+
213
+
214
+ class ImageEncoderEmbedding(nn.Module):
215
+ """Embedding module for spatial inputs, like images or feature maps.
216
+ Creates tokens from patches over the image.
217
+
218
+ This adapter / embedding differs from the one of MultiMAE by taking as input a dict and
219
+ separating positional embeddings and modality embeddings from the input projection
220
+ The input projection is returned under 'x', and the sum of positional and modality embeddings under 'emb'.
221
+
222
+ Args:
223
+ num_channels: Number of input channels of the image/feature map
224
+ patch_size: Int or tuple of the patch size over the full image size.
225
+ dim_tokens: Dimension of output tokens. Can be set using init method.
226
+ sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings
227
+ image_size: Default image size. Used to initialize size of positional embeddings.
228
+ """
229
+ def __init__(self,
230
+ num_channels: int,
231
+ patch_size: Union[int, Tuple[int,int]],
232
+ dim_tokens: Optional[int] = None,
233
+ sincos_pos_emb: bool = True,
234
+ image_size: Union[int, Tuple[int]] = 224):
235
+
236
+ super().__init__()
237
+ self.num_channels = num_channels
238
+ self.patch_size = pair(patch_size)
239
+ self.dim_tokens = dim_tokens
240
+ self.sincos_pos_emb = sincos_pos_emb
241
+ self.image_size = pair(image_size)
242
+ self.num_patches = (self.image_size[0] // self.patch_size[0]) * (self.image_size[1] // self.patch_size[1])
243
+
244
+ if self.dim_tokens is not None:
245
+ self.init(dim_tokens=dim_tokens)
246
+
247
+ def init(self, dim_tokens: int = 768, init_std=0.02):
248
+ """
249
+ Initialize parts of encoder that are dependent on dimension of tokens.
250
+ Should be called when setting up FourM.
251
+
252
+ Args:
253
+ dim_tokens: Dimension of tokens
254
+ init_std: Standard deviation of init
255
+ """
256
+ self.dim_tokens = dim_tokens
257
+
258
+ # Task embedding identifying from which task a given token comes from
259
+ # Fixed-size positional embeddings. Can be interpolated to different input sizes
260
+ h_posemb = self.image_size[0] // self.patch_size[0]
261
+ w_posemb = self.image_size[1] // self.patch_size[1]
262
+ if self.sincos_pos_emb:
263
+ pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens)
264
+ self.register_buffer("pos_emb", pos_emb) # self.pos_emb is now a buffer for FSDP
265
+ else:
266
+ self.pos_emb = nn.Parameter(torch.zeros(1, (h_posemb * w_posemb), self.dim_tokens))
267
+ nn.init.normal_(self.pos_emb, std=init_std)
268
+
269
+ self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
270
+ nn.init.normal_(self.mod_emb, std=init_std)
271
+
272
+ # Image -> tokens projection
273
+ # No bias term here, so modality embedding fully comes from self.mod_emb
274
+ self.proj = nn.Linear(self.num_channels * self.patch_size[0] * self.patch_size[1], self.dim_tokens, bias=False)
275
+
276
+ @torch.jit.ignore
277
+ def no_weight_decay(self):
278
+ return set()
279
+
280
+ def forward(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
281
+ """
282
+ Forward pass through embedding module, transforming image to sequence of tokens.
283
+ Creates corresponding modality and positional embeddings and adds them to the dict.
284
+
285
+ Args:
286
+ d (Dict[str, torch.Tensor]): Modality dict with at least the following key:
287
+ - 'tensor' (torch.Tensor): Input image for each batch. Shape (B, C, H, W) where B is the batch size, C is the number of channels, and H, W are height and width of the image.
288
+
289
+
290
+ Returns:
291
+ Dict[str, torch.Tensor]: Modality dict with added keys:
292
+ - 'x' (torch.Tensor): Embedded token sequence. Shape (B, (H / PH) * (W / PW), D), where PH and PW are the patch sizes
293
+ - 'emb' (torch.Tensor): Sum of positional and modality embeddings for the input sequence. Shape (B, (H / PH) * (W / PW), D)
294
+ """
295
+ x = d['tensor']
296
+ B, C, H, W = x.shape
297
+ assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first'
298
+ assert (H % self.patch_size[0] == 0) and (W % self.patch_size[1] == 0), f'Image sizes {H}x{W} must be divisible by patch sizes {self.patch_size[0]}x{self.patch_size[1]}'
299
+
300
+ # Patchify and project: [B, C, H, W] -> [B, (H/PH)*(W/PW), D]
301
+ x_patch = self.proj(rearrange(x, 'b d (nh ph) (nw pw) -> b (nh nw) (ph pw d)', ph=self.patch_size[0], pw=self.patch_size[1]))
302
+
303
+ # Create positional embedding + modality embedding
304
+ x_emb = repeat(self.pos_emb + self.mod_emb, '() n d -> b n d', b=B)
305
+
306
+ d['x'] = x_patch
307
+ d['emb'] = x_emb
308
+
309
+ return d
310
+
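A standalone sketch of what the einops rearrange in ImageEncoderEmbedding.forward does to the shapes; the sizes below (224x224 RGB input, 16x16 patches) are illustrative defaults only.

import torch
from einops import rearrange

B, C, H, W = 2, 3, 224, 224
ph = pw = 16
x = torch.randn(B, C, H, W)
patches = rearrange(x, 'b d (nh ph) (nw pw) -> b (nh nw) (ph pw d)', ph=ph, pw=pw)
print(patches.shape)  # (2, 196, 768): (B, (H/ph)*(W/pw), ph*pw*C)
# The bias-free nn.Linear(ph*pw*C, dim_tokens) in the module then maps each flattened patch to a D-dim token.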
311
+
312
+ class SequenceEmbEncoderEmbedding(nn.Module):
313
+ """Adapter for sequence emb inputs, like T5-XXL, CLIP text embeddings.
314
+
315
+ Args:
316
+ max_length: Maximum number of tokens in the sequence
317
+ dim_tokens: Dimension of output tokens. Can be set using init method.
318
+ sincos_pos_emb: Set to True (default) to use fixed 1D sin-cos positional embeddings
319
+ padding_idx: Padding index for word embedding
320
+ orig_emb_dim: Dimension of original embeddings
321
+ bottleneck_dim: Dimension of bottleneck layer
322
+ use_bottleneck: Set to True to use bottleneck layer
323
+ """
324
+ def __init__(self,
325
+ max_length: int,
326
+ dim_tokens: Optional[int] = None,
327
+ sincos_pos_emb: bool = True,
328
+ max_sincos_pos_emb: int = 512,
329
+ padding_idx: int = 0,
330
+ orig_emb_dim: int = 4096,
331
+ bottleneck_dim: int = 64,
332
+ use_bottleneck: bool = False,
333
+ ):
334
+ super().__init__()
335
+ self.max_length = max_length
336
+ self.dim_tokens = dim_tokens
337
+ self.sincos_pos_emb = sincos_pos_emb
338
+ self.padding_idx = padding_idx
339
+ self.max_sincos_pos_emb = max_sincos_pos_emb
340
+ self.orig_emb_dim = orig_emb_dim
341
+ self.use_bottleneck = use_bottleneck
342
+ if self.use_bottleneck:
343
+ self.bottleneck_dim = bottleneck_dim
344
+
345
+ if self.dim_tokens is not None:
346
+ self.init(dim_tokens=dim_tokens)
347
+
348
+ def init(self, dim_tokens: int = 768, init_std=0.02):
349
+ """
350
+ Initialize parts of embedding module that are dependent on dimension of tokens.
351
+ Should be called when setting up FourM.
352
+
353
+ Args:
354
+ dim_tokens: Dimension of tokens
355
+ init_std: Standard deviation of init
356
+ """
357
+ self.dim_tokens = dim_tokens
358
+
359
+ # Task embedding identifying from which task a given token comes from
360
+ # Fixed-size positional embeddings. Can be interpolated to different input sizes
361
+ if self.sincos_pos_emb:
362
+ if self.max_length > self.max_sincos_pos_emb:
363
+ raise ValueError(f"Max length ({self.max_length}) is greater than the number of posembs ({self.max_sincos_pos_emb}")
364
+ pos_emb = build_1d_sincos_posemb(max_len=self.max_sincos_pos_emb, embed_dim=self.dim_tokens)[:self.max_length]
365
+ self.register_buffer("pos_emb", pos_emb) # self.pos_emb is now a buffer for FSDP
366
+ else:
367
+ self.pos_emb = nn.Parameter(torch.zeros(1, self.max_length, self.dim_tokens))
368
+ nn.init.normal_(self.pos_emb, std=init_std)
369
+
370
+ self.mod_emb = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
371
+ nn.init.normal_(self.mod_emb, std=init_std)
372
+
373
+ # Token embedding projection
374
+ if self.use_bottleneck:
375
+ self.emb_proj = nn.Sequential(
376
+ nn.Linear(self.orig_emb_dim, self.bottleneck_dim),
377
+ nn.Linear(self.bottleneck_dim, self.dim_tokens),
378
+ )
379
+ else:
380
+ self.emb_proj = nn.Linear(self.orig_emb_dim, self.dim_tokens)
381
+
382
+
383
+ @torch.jit.ignore
384
+ def no_weight_decay(self):
385
+ return set()
386
+
387
+ def forward(self, d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
388
+ """
389
+ Forward pass through embedding module, projecting original embeddings to the Transformer dimension.
390
+ Creates corresponding modality and positional embeddings and adds them to the dict.
391
+
392
+ Args:
393
+ d (Dict[str, torch.Tensor]): Modality dict with at least the following keys:
394
+ - 'tensor' (torch.Tensor): Input token sequence for each batch. Shape (B, L, E) where B is the batch size and L is the sequence length, and E is the dimension of the original embeddings.
395
+ - 'input_mask' (torch.Tensor): Mask for valid tokens in the input sequence (set to 0 for valid tokens and 1 otherwise). Shape (B, L).
396
+
397
+ Returns:
398
+ Dict[str, torch.Tensor]: Modality dict with added keys:
399
+ - 'x' (torch.Tensor): Embedded token sequence. Shape (B, L, D) where D is the Transformer embedding dimension.
400
+ - 'emb' (torch.Tensor): Sum of positional and modality embeddings for the input sequence. Shape (B, L, D).
401
+ """
402
+ orig_emb = d['tensor']
403
+ B = orig_emb.shape[0]
404
+ assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first'
405
+
406
+ # Map to embedding
407
+ x = self.emb_proj(orig_emb)
408
+
409
+ expanded_pos_emb = repeat(self.pos_emb, "() n d -> b n d", b=B)
410
+ # Input pos encoding
411
+ input_mask = d['input_mask']
412
+ input_pos_id = (~input_mask).int().cumsum(dim=1) - 1
413
+ input_pos_id[input_mask] = 0
414
+ input_pos_emb = torch.gather(expanded_pos_emb, dim=1, index=repeat(input_pos_id, "b n -> b n d", d=expanded_pos_emb.shape[2]))
415
+ input_pos_emb[input_mask] = 0
416
+
417
+ x_emb = input_pos_emb + self.mod_emb
418
+
419
+ d['x'] = x
420
+ d['emb'] = x_emb
421
+ return d
422
+
fourm/models/fm.py ADDED
@@ -0,0 +1,1130 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ import random
16
+ import copy
17
+ from functools import partial
18
+ from typing import Any, Dict, Optional, Tuple, Union
19
+
20
+ import torch
21
+ from einops import rearrange, repeat
22
+ from torch import nn
23
+ import torch.nn.functional as F
24
+
25
+ from fourm.utils.timm.registry import register_model
26
+ from huggingface_hub import PyTorchModelHubMixin
27
+
28
+ from .fm_utils import Block, DecoderBlock, LayerNorm
29
+ from fourm.data.modality_info import MODALITY_INFO
30
+
31
+
32
+ # Model definitions
33
+ __all__ = [
34
+ # GELU models
35
+ 'fm_tiny_6e_6d_gelu',
36
+ 'fm_small_8e_8d_gelu',
37
+ 'fm_base_12e_12d_gelu',
38
+ 'fm_large_24e_24d_gelu',
39
+ 'fm_xlarge_24e_24d_gelu',
40
+ # SwiGLU models
41
+ 'fm_tiny_6e_6d_swiglu_nobias',
42
+ 'fm_small_8e_8d_swiglu_nobias',
43
+ 'fm_base_12e_12d_swiglu_nobias',
44
+ 'fm_large_24e_24d_swiglu_nobias',
45
+ 'fm_xlarge_24e_24d_swiglu_nobias',
46
+ # SwiGLU + QKNorm models
47
+ 'fm_base_12e_12d_swiglu_qknorm_nobias',
48
+ 'fm_large_24e_24d_swiglu_qknorm_nobias',
49
+ 'fm_xlarge_24e_24d_swiglu_qknorm_nobias',
50
+ ]
51
+
52
+
53
+
54
+ class FourM(nn.Module):
55
+ """4M model.
56
+
57
+ Args:
58
+ encoder_embeddings: Dict of encoder embedding modules.
59
+ decoder_embeddings: Dict of decoder embedding modules.
60
+ modality_info: Dict containing modality information.
61
+ dim: Embedding dimension.
62
+ encoder_depth: Number of encoder blocks.
63
+ decoder_depth: Number of decoder blocks.
64
+ num_heads: Number of attention heads.
65
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
66
+ qkv_bias: If True, add a learnable bias to query, key, value projections.
67
+ proj_bias: If True, add a learnable bias to the last projection of the attention block.
68
+ mlp_bias: If True, add a learnable bias to linear layers in the MLP / feed-forward.
69
+ drop_path_rate_encoder: Stochastic depth rate for encoder.
70
+ drop_path_rate_decoder: Stochastic depth rate for decoder.
71
+ shared_drop_path: If True, shares drop path between encoder and decoder.
72
+ act_layer: Activation layer to be used.
73
+ norm_layer: Normalization layer to be used.
74
+ gated_mlp: If True, make the feedforward gated (e.g., SwiGLU).
75
+ qk_norm: If True, applies normalization to queries and keys (QKNorm).
76
+ decoder_causal_mask: If True, decoder will use a causal mask for all tokens.
77
+ decoder_sep_mask: If True, decoder attention is restricted to within each modality only.
78
+ num_register_tokens: Number of register tokens.
79
+ use_act_checkpoint: If True, use activation checkpoint for each block.
+ share_modality_embeddings: If True, share modality embeddings between the encoder and decoder embedding modules for modalities present in both.
80
+ """
81
+ def __init__(self,
82
+ encoder_embeddings: Dict[str, nn.Module],
83
+ decoder_embeddings: Dict[str, nn.Module],
84
+ modality_info: Dict[str, Any],
85
+ dim: int = 768,
86
+ encoder_depth: int = 12,
87
+ decoder_depth: int = 12,
88
+ num_heads: int = 12,
89
+ mlp_ratio: float = 4.0,
90
+ qkv_bias: bool = True,
91
+ proj_bias: bool = True,
92
+ mlp_bias: bool = True,
93
+ drop_path_rate_encoder: float = 0.0,
94
+ drop_path_rate_decoder: float = 0.0,
95
+ shared_drop_path: bool = False,
96
+ act_layer: nn.Module = nn.GELU,
97
+ norm_layer: Union[partial, nn.Module] = partial(LayerNorm, eps=1e-6),
98
+ gated_mlp: bool = False, # Make the feedforward gated for e.g. SwiGLU
99
+ qk_norm: bool = False,
100
+ decoder_causal_mask: bool = False,
101
+ decoder_sep_mask: bool = True,
102
+ num_register_tokens: int = 0,
103
+ use_act_checkpoint: bool = False,
104
+ share_modality_embeddings: bool = True,
105
+ ):
106
+ super().__init__()
107
+
108
+ self.modality_info = modality_info
109
+ self.dim = dim
110
+ self.decoder_causal_mask = decoder_causal_mask
111
+ self.decoder_sep_mask = decoder_sep_mask
112
+ self.init_std = 0.02
113
+ self.use_act_checkpoint = use_act_checkpoint
114
+ self.num_register_tokens = num_register_tokens
115
+
116
+
117
+ # Encoder embeddings & init
118
+ self.encoder_modalities = set(encoder_embeddings.keys())
119
+ for emb in encoder_embeddings.values():
120
+ emb.init(dim_tokens=dim, init_std=self.init_std)
121
+ self.encoder_embeddings = nn.ModuleDict(encoder_embeddings)
122
+
123
+ # Decoder embeddings & init
124
+ self.decoder_modalities = set(decoder_embeddings.keys())
125
+ for emb in decoder_embeddings.values():
126
+ emb.init(dim_tokens=dim, init_std=self.init_std)
127
+ self.decoder_embeddings = nn.ModuleDict(decoder_embeddings)
128
+
129
+ # Share modality embeddings across the encoder and decoder embedding modules
130
+ if share_modality_embeddings:
131
+ self.share_modality_embeddings()
132
+
133
+ ## Transformer encoder
134
+ if shared_drop_path:
135
+ dpr_encoder = [x.item() for x in torch.linspace(0, drop_path_rate_encoder, encoder_depth + decoder_depth)][:encoder_depth]
136
+ else:
137
+ dpr_encoder = [x.item() for x in torch.linspace(0, drop_path_rate_encoder, encoder_depth)] # stochastic depth decay rule
138
+
139
+ self.encoder = nn.ModuleList([
140
+ Block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias,
141
+ drop_path=dpr_encoder[i], act_layer=act_layer, norm_layer=norm_layer, gated_mlp=gated_mlp, qk_norm=qk_norm)
142
+ for i in range(encoder_depth)
143
+ ])
144
+ self.encoder_norm = norm_layer(dim)
145
+
146
+
147
+ ## Transformer decoder
148
+ if shared_drop_path:
149
+ dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate_decoder, encoder_depth + decoder_depth)][encoder_depth:]
150
+ else:
151
+ dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate_decoder, decoder_depth)] # stochastic depth decay rule
152
+
153
+ # Projection of encoder tokens before adding the embeddings again
154
+ self.decoder_proj_context = nn.Linear(dim, dim)
155
+
156
+ self.decoder = nn.ModuleList([
157
+ DecoderBlock(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias,
158
+ drop_path=dpr_decoder[i], act_layer=act_layer, norm_layer=norm_layer, gated_mlp=gated_mlp, qk_norm=qk_norm)
159
+ for i in range(decoder_depth)
160
+ ])
161
+ self.decoder_norm = norm_layer(dim)
162
+
163
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, dim))
164
+ nn.init.normal_(self.mask_token, std=self.init_std)
165
+
166
+ # Additional register tokens that can be used by the encoder during fine-tuning
167
+ if self.num_register_tokens > 0:
168
+ self.register_tokens = nn.Parameter(torch.zeros(1, self.num_register_tokens, dim))
169
+ nn.init.normal_(self.register_tokens, std=self.init_std)
170
+ else:
171
+ self.register_tokens = None
172
+
173
+ # Weight init
174
+ self.init_weights()
175
+
176
+ def share_modality_embeddings(self):
177
+ """Share modality embeddings across the encoder and decoder embedding modules."""
178
+ shared_modalities = self.encoder_modalities & self.decoder_modalities
179
+ for mod in shared_modalities:
180
+ self.decoder_embeddings[mod].mod_emb = self.encoder_embeddings[mod].mod_emb
181
+
182
+ def init_weights(self):
183
+ """Weight initialization following MAE's initialization scheme"""
184
+
185
+ for name, m in self.named_modules():
186
+ # Skipping tokenizers to avoid reinitializing them
187
+ if "tokenizer" in name:
188
+ continue
189
+ # Linear
190
+ elif isinstance(m, nn.Linear):
191
+ if 'qkv' in name:
192
+ # treat the weights of Q, K, V separately
193
+ val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
194
+ nn.init.uniform_(m.weight, -val, val)
195
+ elif 'kv' in name:
196
+ # treat the weights of K, V separately
197
+ val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
198
+ nn.init.uniform_(m.weight, -val, val)
199
+ else:
200
+ nn.init.xavier_uniform_(m.weight)
201
+ if isinstance(m, nn.Linear) and m.bias is not None:
202
+ nn.init.constant_(m.bias, 0)
203
+ # LayerNorm
204
+ elif isinstance(m, nn.LayerNorm) or isinstance(m, LayerNorm):
205
+ nn.init.constant_(m.weight, 1.0)
206
+ if m.bias is not None:
207
+ nn.init.constant_(m.bias, 0)
208
+ # Embedding
209
+ elif isinstance(m, nn.Embedding):
210
+ nn.init.normal_(m.weight, std=self.init_std)
211
+ # Conv2d
212
+ elif isinstance(m, nn.Conv2d):
213
+ if '.proj' in name:
214
+ # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
215
+ w = m.weight.data
216
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
217
+
218
+ def get_num_layers_encoder(self):
219
+ return len(self.encoder)
220
+
221
+ def get_num_layers_decoder(self):
222
+ return len(self.decoder)
223
+
224
+ def get_num_layers(self):
225
+ return self.get_num_layers_encoder() + self.get_num_layers_decoder()
226
+
227
+ @torch.jit.ignore
228
+ def no_weight_decay(self):
229
+ no_wd_set = set()
230
+
231
+ for mod, emb_module in self.encoder_embeddings.items():
232
+ if hasattr(emb_module, 'no_weight_decay'):
233
+ to_skip = emb_module.no_weight_decay()
234
+ to_skip = set([f'encoder_embeddings.{mod}.{name}' for name in to_skip])
235
+ no_wd_set = no_wd_set | to_skip
236
+
237
+ for mod, emb_module in self.decoder_embeddings.items():
238
+ if hasattr(emb_module, 'no_weight_decay'):
239
+ to_skip = emb_module.no_weight_decay()
240
+ to_skip = set([f'decoder_embeddings.{mod}.{name}' for name in to_skip])
241
+ no_wd_set = no_wd_set | to_skip
242
+
243
+ return no_wd_set
244
+
245
+ def cat_encoder_tensors(self, mod_dict: Dict[str, Dict[str, torch.Tensor]]) -> Tuple[torch.Tensor, ...]:
246
+ """Concatenate encoder tensors from different modalities.
247
+
248
+ Args:
249
+ mod_dict (dict): A dictionary containing information for each modality.
250
+ Expected keys for each modality are 'x' (input tokens),
251
+ 'emb' (embeddings), 'input_mask', etc.
252
+
253
+ Returns:
254
+ tuple:
255
+ - encoder_tokens_all (torch.Tensor): Concatenated encoder tokens from all modalities. Shape (B, O, D) where O is the total number of all encoder tokens.
256
+ - emb_all (torch.Tensor): Concatenated encoder embeddings from all modalities. Shape (B, O, D)
257
+ - encoder_mask_all (torch.Tensor): Concatenated boolean masks indicating which tokens are part of the encoder input (set to 0 for valid tokens, 1 otherwise). Shape (B, O)
258
+ - mod_mask_all (torch.Tensor): Concatenated integer mask marking the modality type for each encoder token. Shape (B, O)
259
+ """
260
+
261
+ encoder_tokens_all = []
262
+ emb_all = []
263
+ encoder_mask_all = []
264
+ mod_mask_all = []
265
+
266
+ for mod, d in mod_dict.items():
267
+ encoder_tokens_all.append(d['x'])
268
+ emb_all.append(d['emb'])
269
+ encoder_mask_all.append(d['input_mask'])
270
+ mod_mask_all.append(torch.full_like(d['input_mask'], self.modality_info[mod]['id'], dtype=torch.int16))
271
+
272
+ encoder_tokens_all = torch.cat(encoder_tokens_all, dim=1)
273
+ emb_all = torch.cat(emb_all, dim=1)
274
+ encoder_mask_all = torch.cat(encoder_mask_all, dim=1)
275
+ mod_mask_all = torch.cat(mod_mask_all, dim=1)
276
+
277
+ return encoder_tokens_all, emb_all, encoder_mask_all, mod_mask_all
278
+
279
+ def cat_decoder_tensors(self, mod_dict: Dict[str, Dict[str, torch.Tensor]]) -> Tuple[torch.Tensor]:
280
+ """Concatenate decoder tensors from different modalities.
281
+
282
+ Args:
283
+ mod_dict (dict): A dictionary containing information for each modality.
284
+ Expected keys for each modality include 'x' (input tokens),
285
+ 'ids' (target IDs), 'emb' (embeddings), 'target_mask', 'decoder_attention_mask', etc.
286
+
287
+
288
+ Returns:
289
+ tuple:
290
+ - decoder_tokens_all (torch.Tensor): Concatenated decoder tokens from all modalities. Shape (B, P, D) where P is the total number of all decoder tokens.
291
+ - emb_all (torch.Tensor): Concatenated decoder embeddings from all modalities. Shape (B, P, D)
292
+ - decoder_mask_all (torch.Tensor): Concatenated boolean masks indicating which tokens are part of the decoder input / target (set to 0 for valid tokens, 1 otherwise). Shape (B, P)
293
+ - target_ids_all (torch.Tensor): Concatenated target IDs from all modalities. Shape (B, P)
294
+ - attention_mask_all (torch.Tensor): Concatenated attention masks in compressed format, needs to be passed to adapt_decoder_attention_mask() to obtain the final attention mask. Shape (B, P)
295
+ - mod_mask_all (torch.Tensor): Concatenated integer mask marking the modality type for each decoder token. Shape (B, P)
296
+ """
297
+
298
+ decoder_tokens_all = []
299
+ target_ids_all = []
300
+ emb_all = []
301
+ decoder_mask_all = []
302
+ attention_mask_all = []
303
+ mod_mask_all = []
304
+
305
+ # Shuffle order in which modalities are provided (useful for modality causal mask)
306
+ mod_dict = {mod: d for mod, d in random.sample(list(mod_dict.items()), len(mod_dict))}
307
+
308
+ for mod, d in mod_dict.items():
309
+ if self.modality_info[mod]['type'] in ['seq', 'seq_emb', 'seq_token']:
310
+ # Important: This makes the assumption that the target sequence appears sequentially
311
+ # before sorting / gathering
312
+ decoder_tokens_all.append(d['x'][:, :-1])
313
+ target_ids_all.append(d['ids'][:, 1:]) # Shifted left
314
+ emb_all.append(d['emb'][:, :-1])
315
+ # Logical or with left shifting removes the last unmasked position
316
+ decoder_mask_all.append(torch.logical_or(d['target_mask'][:, 1:], d['target_mask'][:, :-1]))
317
+ # Add attention mask ids
318
+ attention_mask_all.append(d['decoder_attention_mask'][:, :-1])
319
+ mod_mask_all.append(torch.full_like(d['ids'][:, :-1], self.modality_info[mod]['id'], dtype=torch.int16))
320
+ else:
321
+ # Important: For 2d / image modalities, the decoder input tokens are replaced by the mask token
322
+ decoder_tokens_all.append(torch.zeros_like(d['x']) + self.mask_token) # Replace x by mask token
323
+ target_ids_all.append(d['ids'])
324
+ emb_all.append(d['emb'])
325
+ decoder_mask_all.append(d['target_mask'])
326
+ attention_mask_all.append(d['decoder_attention_mask'])
327
+ mod_mask_all.append(torch.full_like(d['ids'], self.modality_info[mod]['id'], dtype=torch.int16))
328
+
329
+ decoder_tokens_all = torch.cat(decoder_tokens_all, dim=1)
330
+ emb_all = torch.cat(emb_all, dim=1)
331
+ decoder_mask_all = torch.cat(decoder_mask_all, dim=1)
332
+ target_ids_all = torch.cat(target_ids_all, dim=1)
333
+ attention_mask_all = torch.cat(attention_mask_all, dim=1)
334
+ mod_mask_all = torch.cat(mod_mask_all, dim=1)
335
+
336
+ return decoder_tokens_all, emb_all, decoder_mask_all, target_ids_all, attention_mask_all, mod_mask_all
337
+
338
+ def forward_mask_encoder(self, mod_dict: Dict[str, Dict[str, torch.Tensor]], num_encoder_tokens: int) -> Tuple[torch.Tensor]:
339
+ """Concatenates and mask encoder tensors based on provided modality information.
340
+
341
+ This function consolidates encoder tokens from multiple modalities, then selects a specified number of them based on modality information (i.e. masking).
342
+
343
+ Args:
344
+ mod_dict (dict): Dictionary containing tensors for different modalities.
345
+ It is expected to have keys for each modality and values
346
+ containing the modalities' associated tensors.
347
+ num_encoder_tokens (int): Number of encoder tokens to retain after masking.
348
+
349
+ Returns:
350
+ tuple:
351
+ - encoder_tokens (torch.Tensor): Selected encoder tokens from all modalities. Shape (B, N, D) where N is the number of selected encoder tokens.
352
+ - encoder_emb (torch.Tensor): Corresponding embeddings for encoder tokens. Shape (B, N, D)
353
+ - encoder_mask (torch.Tensor): A boolean mask indicating which encoder tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N)
354
+ - mod_mask (torch.Tensor): An integer mask marking the modality type for each encoder token (with -1 indicating unassigned pad tokens). Shape (B, N)
355
+
356
+ Notes:
357
+ - If `num_register_tokens` is set and greater than 0, register tokens are added at the beginning of the sequence.
358
+ """
359
+ B = list(mod_dict.values())[0]['tensor'].shape[0]
360
+
361
+ encoder_tokens_all, emb_all, encoder_mask_all, mod_mask_all = self.cat_encoder_tensors(mod_dict)
362
+
363
+ # Add arange multiplied by small constant to mask so they get sorted in a deterministic way
364
+ mask_arange = torch.arange(encoder_mask_all.shape[1], device=encoder_mask_all.device).unsqueeze(0) * 1e-6
365
+ ids_shuffle = torch.argsort(encoder_mask_all + mask_arange, dim=1)
366
+ # ids_restore = torch.argsort(ids_shuffle, dim=1)
367
+ ids_keep = ids_shuffle[:, :num_encoder_tokens]
368
+
369
+ encoder_tokens = torch.gather(encoder_tokens_all, dim=1,
370
+ index=repeat(ids_keep, "b n -> b n d", d=encoder_tokens_all.shape[2]))
371
+ encoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
372
+ encoder_mask = torch.gather(encoder_mask_all, dim=1, index=ids_keep)
373
+ mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)
374
+
375
+ if self.num_register_tokens > 0:
376
+ register_tokens = repeat(self.register_tokens, '() n d -> b n d', b=B)
377
+ # We add register tokens at the beginning of the sequence
378
+ encoder_tokens = torch.cat([register_tokens, encoder_tokens], dim=1)
379
+ encoder_emb = torch.cat([torch.zeros_like(register_tokens), encoder_emb], dim=1)
380
+ encoder_mask = torch.cat([torch.zeros((B, register_tokens.shape[1]), dtype=torch.bool, device=encoder_mask.device), encoder_mask], dim=1)
381
+ mod_mask = torch.cat([torch.full((B, register_tokens.shape[1]), -1, dtype=torch.int16, device=mod_mask.device), mod_mask], dim=1)
382
+
383
+ encoder_tokens[encoder_mask] = 0.
384
+ encoder_emb[encoder_mask] = 0.
385
+ mod_mask[encoder_mask] = -1
386
+ # Mask could be of shape 'b n1 n2' but not needed for masked_fill
387
+ # This means this mask can then be re-used for decoder cross-attention
388
+ encoder_mask = rearrange(encoder_mask, 'b n2 -> b 1 n2')
389
+
390
+ return encoder_tokens, encoder_emb, encoder_mask, mod_mask
391
+
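A small standalone sketch of the selection trick used in forward_mask_encoder above: adding a tiny arange to the 0/1 mask and argsorting puts the valid (mask 0) tokens first while preserving their original order, so keeping the first N indices implements the masking. The values below are made up.

import torch

mask = torch.tensor([[0., 1., 0., 1., 0.]])             # 0 = valid token, 1 = masked out
mask_arange = torch.arange(mask.shape[1]).unsqueeze(0) * 1e-6
ids_shuffle = torch.argsort(mask + mask_arange, dim=1)  # tensor([[0, 2, 4, 1, 3]])
ids_keep = ids_shuffle[:, :3]                           # keep 3 tokens
print(ids_keep)                                         # tensor([[0, 2, 4]]) - the valid positions, in order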
392
+ def forward_mask_decoder(self, mod_dict: Dict[str, Dict[str, torch.Tensor]], num_decoder_tokens: int) -> Tuple[torch.Tensor]:
393
+ """Concatenates and mask decoder tensors based on provided modality information.
394
+
395
+ This function consolidates decoder tokens from multiple modalities, selects a specified number of them based on modality information, and applies appropriate masking.
396
+
397
+ Args:
398
+ mod_dict (dict): Dictionary containing tensors for different modalities.
399
+ It is expected to have keys for each modality and values
400
+ containing the modalities' associated tensors.
401
+ num_decoder_tokens (int): Number of decoder tokens to retain after masking.
402
+
403
+ Returns:
404
+ tuple:
405
+ - decoder_tokens (torch.Tensor): Selected decoder tokens from all modalities. Shape (B, M, D) where M is the number of selected decoder tokens.
406
+ - decoder_emb (torch.Tensor): Corresponding embeddings for decoder tokens. Shape (B, M, D)
407
+ - decoder_mask (torch.Tensor): A boolean mask indicating which decoder tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, M)
408
+ - target_ids (torch.Tensor): IDs of the target tokens corresponding to the decoder tokens. Shape (B, M)
409
+ - decoder_attention_mask (torch.Tensor): Mask for the decoder self-attention layers. Shape (B, M, M)
410
+ - mod_mask (torch.Tensor): An integer mask marking the modality type for each decoder token (with -1 indicating unassigned pad tokens). Shape (B, M)
411
+ """
412
+ # decoder_mask and target_mask are equivalent, we rename it here to harmonize with forward_mask_encoder
413
+ decoder_tokens_all, emb_all, decoder_mask_all, target_ids_all, decoder_attention_mask_all, mod_mask_all = self.cat_decoder_tensors(mod_dict)
414
+
415
+ # Add arange multiplied by small constant to mask so they get sorted in a deterministic way
416
+ mask_arange = torch.arange(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6
417
+ ids_shuffle = torch.argsort(decoder_mask_all + mask_arange, dim=1)
418
+ # ids_restore = torch.argsort(ids_shuffle, dim=1)
419
+ ids_keep = ids_shuffle[:, :num_decoder_tokens]
420
+
421
+ decoder_tokens = torch.gather(decoder_tokens_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=decoder_tokens_all.shape[2]))
422
+ decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
423
+ decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep)
424
+ target_ids = torch.gather(target_ids_all, dim=1, index=ids_keep)
425
+ decoder_attention_mask = torch.gather(decoder_attention_mask_all, dim=1, index=ids_keep)
426
+ mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)
427
+
428
+ decoder_tokens[decoder_mask] = 0.
429
+ decoder_emb[decoder_mask] = 0.
430
+ target_ids[decoder_mask] = 0
431
+ decoder_attention_mask = self.adapt_decoder_attention_mask(decoder_attention_mask, mod_mask)
432
+ mod_mask[decoder_mask] = -1
433
+
434
+ # This means this mask can then be re-used for decoder cross-attention
435
+ decoder_mask = rearrange(decoder_mask, 'b n2 -> b 1 n2')
436
+
437
+
438
+ return decoder_tokens, decoder_emb, decoder_mask, target_ids, decoder_attention_mask, mod_mask
439
+
440
+ def adapt_decoder_attention_mask(self, decoder_attention_mask: torch.Tensor, mod_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
441
+ """
442
+ Transforms the compressed decoder attention mask to a full attention mask based on the specified constraints.
443
+
444
+ Args:
445
+ decoder_attention_mask (torch.Tensor): Initial attention mask indicating attention constraints. Shape (B, M) where M is the number of the decoder tokens.
446
+ mod_mask (torch.Tensor, optional): Modality mask to separate attention masks per modality. Shape (B, M)
447
+
448
+ Returns:
449
+ torch.Tensor: Adapted attention mask. Shape (B, M, M) where M is the number of the decoder tokens.
450
+ """
451
+ B, N = decoder_attention_mask.shape
452
+
453
+ if self.decoder_causal_mask:
454
+ # For causal mode, tokens can only attend to preceding tokens and themselves.
455
+ causal_mask = torch.ones((N, N), dtype=torch.bool, device=decoder_attention_mask.device).triu(1)
456
+ causal_mask = repeat(causal_mask, "n1 n2 -> b n1 n2", b=B)
457
+ adapted_attention_mask = causal_mask
458
+ else:
459
+ # Cumulatively sum the attention mask to determine token-wise attention behavior.
460
+ # Examples:
461
+ # Mask [4, 0, 0, 0] -> Cumsum: [4, 4, 4, 4] -> All tokens attend to each other.
462
+ # Mask [1, 1, 1, 1] -> Cumsum: [1, 2, 3, 4] -> Strict autoregressive behavior.
463
+ # Mask [2, 0, 1, 1] -> Cumsum: [2, 2, 3, 4] -> Tokens 1 and 2 attend to each other, token 3 attends to tokens 1-3, and token 4 to all.
464
+ attention_arange = torch.arange(N, device=decoder_attention_mask.device)
465
+ attention_arange = repeat(attention_arange, "n2 -> b n1 n2", b=B, n1=N)
466
+ cumsum_mask = torch.cumsum(decoder_attention_mask, dim=-1)
467
+ cumsum_mask = rearrange(cumsum_mask, "b n -> b n 1")
468
+ adapted_attention_mask = (attention_arange >= cumsum_mask)
469
+
470
+ if self.decoder_sep_mask:
471
+ # Separate attention between tokens based on their modality using mod_mask.
472
+ sep_mask = repeat(mod_mask, "b n2 -> b n1 n2", n1=N) != repeat(mod_mask, "b n1 -> b n1 n2", n2=N)
473
+ adapted_attention_mask = adapted_attention_mask | sep_mask
474
+
475
+ return adapted_attention_mask
476
+
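A standalone numeric sketch of the non-causal branch of adapt_decoder_attention_mask above, reproducing the third example from its comments ([2, 0, 1, 1]); True marks positions a token is blocked from attending to.

import torch
from einops import rearrange, repeat

decoder_attention_mask = torch.tensor([[2, 0, 1, 1]])
B, N = decoder_attention_mask.shape
attention_arange = repeat(torch.arange(N), "n2 -> b n1 n2", b=B, n1=N)
cumsum_mask = rearrange(torch.cumsum(decoder_attention_mask, dim=-1), "b n -> b n 1")  # [[2, 2, 3, 4]]
blocked = attention_arange >= cumsum_mask
print(blocked[0].int())
# tensor([[0, 0, 1, 1],   token 1 attends to tokens 1-2
#         [0, 0, 1, 1],   token 2 attends to tokens 1-2
#         [0, 0, 0, 1],   token 3 attends to tokens 1-3
#         [0, 0, 0, 0]])  token 4 attends to all tokens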
477
+ def forward_encoder(self,
478
+ x: torch.Tensor,
479
+ encoder_mask: torch.Tensor) -> torch.Tensor:
480
+ """Forward pass for the encoder.
481
+
482
+ Args:
483
+ x (torch.Tensor): Encoder input tokens. Shape (B, N, D) where N is the number of encoder tokens.
484
+ encoder_mask (torch.Tensor): Encoder mask indicating which tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N)
485
+
486
+ Returns:
487
+ torch.Tensor: Encoder output. Shape (B, N, D)
488
+ """
489
+
490
+ for blk in self.encoder:
491
+ x = blk(x, mask=encoder_mask)
492
+
493
+ x = self.encoder_norm(x)
494
+
495
+ return x
496
+
497
+ def forward_decoder(self,
498
+ y: torch.Tensor,
499
+ context: torch.Tensor,
500
+ encoder_mask: torch.Tensor,
501
+ decoder_attention_mask: torch.Tensor) -> torch.Tensor:
502
+ """Forward pass for the decoder.
503
+
504
+ Args:
505
+ y (torch.Tensor): Decoder input tokens. Shape (B, M, D).
506
+ context (torch.Tensor): Context for the decoder (i.e. encoder output). Shape (B, N, D).
507
+ encoder_mask (torch.Tensor): Encoder mask indicating which tokens are valid (set to 0 for valid tokens, 1 otherwise). Shape (B, 1, N).
508
+ decoder_attention_mask (torch.Tensor): Decoder attention mask. Shape (B, M, M).
509
+
510
+ Returns:
511
+ torch.Tensor: Decoder output. Shape (B, M, D).
512
+ """
513
+
514
+ for blk in self.decoder:
515
+ y = blk(y, context, sa_mask=decoder_attention_mask, xa_mask=encoder_mask)
516
+
517
+ y = self.decoder_norm(y)
518
+
519
+ return y
520
+
521
+ def forward_logits(self,
522
+ y: torch.Tensor,
523
+ decoder_mod_dict: Dict[str, Dict[str, torch.Tensor]],
524
+ decoder_mod_mask: torch.Tensor,
525
+ return_all_logits: bool = False) -> Dict[str, torch.Tensor]:
526
+ """Forward computation of logits for each modality.
527
+
528
+ Args:
529
+ y (torch.Tensor): Decoder output. Shape (B, M, D).
530
+ decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder.
531
+ decoder_mod_mask (torch.Tensor): Integer mask indicating which tokens belong to which modality. Shape (B, M).
532
+
533
+ Returns:
534
+ Dict[str, torch.Tensor]: Dictionary of logits for each modality.
535
+ """
536
+
537
+ mod_logits = {}
538
+ for mod, d in decoder_mod_dict.items():
539
+ idx = self.modality_info[mod]["id"]
540
+ if return_all_logits:
541
+ logits = self.decoder_embeddings[mod].forward_logits(y)
542
+ else:
543
+ logits = self.decoder_embeddings[mod].forward_logits(y[decoder_mod_mask == idx])
544
+ mod_logits[mod] = logits
545
+ return mod_logits
546
+
547
+ def forward_loss(self,
548
+ y: torch.Tensor,
549
+ target_ids: torch.Tensor,
550
+ decoder_mod_dict: Dict[str, Any],
551
+ decoder_mod_mask: torch.Tensor, loss_type: str) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
552
+ """Computes the loss based on the specified loss type.
553
+
554
+ Args:
555
+ y (torch.Tensor): Decoder output. Shape (B, M, D).
556
+ target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M).
557
+ decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder.
558
+ decoder_mod_mask (torch.Tensor): Integer mask indicating which tokens belong to which modality. Shape (B, M).
559
+ loss_type (str): The type of loss to compute. Either 'mod' or 'token'.
560
+
561
+ Returns:
562
+ Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total loss and dictionary of loss for each modality.
563
+ """
564
+ if loss_type in ['mod', 'modality']:
565
+ loss, mod_loss = self.forward_mod_loss(y, target_ids, decoder_mod_dict, decoder_mod_mask)
566
+ elif loss_type == 'token':
567
+ loss, mod_loss = self.forward_token_loss(y, target_ids, decoder_mod_dict, decoder_mod_mask)
568
+ else:
569
+ raise ValueError("Invalid loss type")
570
+
571
+ return loss, mod_loss
572
+
573
+ def forward_mod_loss(self,
574
+ y: torch.Tensor,
575
+ target_ids: torch.Tensor,
576
+ decoder_mod_dict: Dict[str, Any],
577
+ decoder_mod_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
578
+ """Computes the modality-wise loss.
579
+
580
+ Args:
581
+ y (torch.Tensor): Decoder tokens. Shape (B, M, D).
582
+ target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M).
583
+ decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder.
584
+ decoder_mod_mask (torch.Tensor): Mask indicating which tokens belong to which modality. Shape (B, M).
585
+
586
+ Returns:
587
+ Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total modality loss and dictionary of loss for each modality.
588
+ """
589
+ mod_loss = {}
590
+ for mod, d in decoder_mod_dict.items():
591
+ idx = self.modality_info[mod]["id"]
592
+ logits = self.decoder_embeddings[mod].forward_logits(y[decoder_mod_mask == idx])
593
+ if logits.numel() == 0:
594
+ # If there are no logits / targets, set mod_loss to 0
595
+ mod_loss[mod] = torch.zeros(1, device=logits.device)
596
+ else:
597
+ loss = F.cross_entropy(logits, target_ids[decoder_mod_mask == idx].long(), reduction='mean')
598
+ mod_loss[mod] = loss
599
+
600
+ loss = sum(mod_loss.values()) / len(mod_loss)
601
+
602
+ return loss, mod_loss
603
+
604
+ def forward_token_loss(self,
605
+ y: torch.Tensor,
606
+ target_ids: torch.Tensor,
607
+ decoder_mod_dict: Dict[str, Any],
608
+ decoder_mod_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
609
+ """Computes the token-wise loss.
610
+
611
+ Args:
612
+ y (torch.Tensor): Decoder tokens. Shape (B, M, D).
613
+ target_ids (torch.Tensor): Ground truth token IDs. Shape (B, M).
614
+ decoder_mod_dict (dict): Dictionary containing tensor information for each modality in the decoder.
615
+ decoder_mod_mask (torch.Tensor): Mask indicating which tokens belong to which modality. Shape (B, M).
616
+
617
+ Returns:
618
+ Tuple[torch.Tensor, Dict[str, torch.Tensor]]: Total token loss and dictionary of loss for each modality.
619
+ """
620
+ mod_loss = {}
621
+ mod_count = {}
622
+
623
+ for mod, d in decoder_mod_dict.items():
624
+ idx = self.modality_info[mod]["id"]
625
+ logits = self.decoder_embeddings[mod].forward_logits(y[decoder_mod_mask == idx])
626
+ if logits.numel() == 0:
627
+ # If there are no logits / targets, set mod_loss to 0
628
+ mod_loss[mod] = torch.zeros(1, device=logits.device)
629
+ mod_count[mod] = 0
630
+ else:
631
+ loss = F.cross_entropy(logits, target_ids[decoder_mod_mask == idx].long(), reduction='mean')
632
+ mod_loss[mod] = loss
633
+ mod_count[mod] = logits.numel()
634
+
635
+ loss = sum([mod_loss[mod] * mod_count[mod] for mod in mod_loss.keys()]) / sum(mod_count.values())
636
+
637
+ return loss, mod_loss
638
+
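The difference between the 'mod' and 'token' loss types above is only in how the per-modality cross-entropies are averaged; a minimal numeric sketch with made-up values (in the code, the per-modality count is logits.numel()):

mod_loss = {'caption': 2.0, 'tok_rgb': 4.0}   # per-modality mean losses (made up)
mod_count = {'caption': 10, 'tok_rgb': 30}    # per-modality counts (made up)

loss_mod = sum(mod_loss.values()) / len(mod_loss)  # 3.0: every modality weighted equally
loss_token = sum(mod_loss[m] * mod_count[m] for m in mod_loss) / sum(mod_count.values())  # 3.5: weighted by count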
639
+
640
+ def forward(self,
641
+ mod_dict: Dict[str, Dict[str, torch.Tensor]],
642
+ num_encoder_tokens: int,
643
+ num_decoder_tokens: int,
644
+ loss_type: str = 'mod',
645
+ return_logits: bool = False) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
646
+ """
647
+ Forward pass for the model.
648
+
649
+ Args:
650
+ mod_dict (Dict[str, Dict[str, torch.Tensor]]): Dictionary containing the tensors, masks, and other info for each modality.
651
+ - mod_dict[modality_name]["tensor_name"]: Shape can vary based on tensor_name and modality.
652
+ num_encoder_tokens (int): Number of tokens to keep for the encoder.
653
+ num_decoder_tokens (int): Number of tokens to keep for the decoder.
654
+ loss_type (str, optional): The type of loss to compute. Can be 'mod' (average of loss per modality) or 'token' (average loss per token). Default is 'mod'.
655
+ return_logits (bool, optional): If True, return the logits. Default is False.
656
+
657
+ Returns:
658
+ Union[dict, tuple]:
659
+ - If return_logits is True: Dictionary of logits for each modality.
660
+ - Otherwise: Tuple containing the total loss and dictionary of loss for each modality.
661
+ """
662
+
663
+ # Mod dicts
664
+ encoder_mod_dict = {mod: self.encoder_embeddings[mod](d)
665
+ for mod, d in mod_dict.items()
666
+ if mod in self.encoder_embeddings}
667
+ encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder(encoder_mod_dict, num_encoder_tokens)
668
+
669
+ decoder_mod_dict = {mod: self.decoder_embeddings[mod].forward_embed(d)
670
+ for mod, d in mod_dict.items()
671
+ if mod in self.decoder_embeddings}
672
+ decoder_tokens, decoder_emb, decoder_mask, target_ids, decoder_attention_mask, decoder_mod_mask = self.forward_mask_decoder(decoder_mod_dict, num_decoder_tokens)
673
+
674
+ # Encoder
675
+ x = encoder_tokens + encoder_emb
676
+ x = self.forward_encoder(x, encoder_mask=encoder_mask)
677
+
678
+ # Decoder
679
+ context = self.decoder_proj_context(x) + encoder_emb
680
+ y = decoder_tokens + decoder_emb
681
+ y = self.forward_decoder(y, context, encoder_mask=encoder_mask, decoder_attention_mask=decoder_attention_mask)
682
+
683
+ # Logits
684
+ if return_logits:
685
+ mod_logits = self.forward_logits(y, decoder_mod_dict, decoder_mod_mask, return_all_logits=True)
686
+ return mod_logits
687
+
688
+ # Loss
689
+ loss, mod_loss = self.forward_loss(y, target_ids, decoder_mod_dict, decoder_mod_mask, loss_type)
690
+
691
+ return loss, mod_loss
692
+
693
+
694
+ def freeze_encoder(self, freeze_embeddings=True):
695
+ for param in self.encoder.parameters():
696
+ param.requires_grad = False
697
+
698
+ for param in self.encoder_norm.parameters():
699
+ param.requires_grad = False
700
+
701
+ if freeze_embeddings:
702
+ for param in self.encoder_embeddings.parameters():
703
+ param.requires_grad = False
704
+
705
+ def freeze_encoder_except_specific_embeddings(self, frozen_embedding_domain):
706
+ frozen_embedding_domain = frozen_embedding_domain.split('-')
707
+ for param in self.encoder.parameters():
708
+ param.requires_grad = False
709
+
710
+ for param in self.encoder_norm.parameters():
711
+ param.requires_grad = False
712
+
713
+ for name, param in self.encoder_embeddings.named_parameters():
714
+ if name.split('.')[0] in frozen_embedding_domain:
715
+ param.requires_grad = False
716
+
717
+ def unfreeze_encoder(self, unfreeze_embeddings=True):
718
+ for param in self.encoder.parameters():
719
+ param.requires_grad = True
720
+
721
+ for param in self.encoder_norm.parameters():
722
+ param.requires_grad = True
723
+
724
+ if unfreeze_embeddings:
725
+ for param in self.encoder_embeddings.parameters():
726
+ param.requires_grad = True
727
+
728
+ def freeze_decoder(self, freeze_embeddings=True):
729
+ for param in self.decoder.parameters():
730
+ param.requires_grad = False
731
+
732
+ for param in self.decoder_norm.parameters():
733
+ param.requires_grad = False
734
+
735
+ if freeze_embeddings:
736
+ for param in self.decoder_embeddings.parameters():
737
+ param.requires_grad = False
738
+
739
+ def freeze_decoder_except_specific_embeddings(self, frozen_embedding_domain):
740
+ frozen_embedding_domain = frozen_embedding_domain.split('-')
741
+ for param in self.decoder.parameters():
742
+ param.requires_grad = False
743
+
744
+ for param in self.decoder_norm.parameters():
745
+ param.requires_grad = False
746
+
747
+ for name, param in self.decoder_embeddings.named_parameters():
748
+ if name.split('.')[0] in frozen_embedding_domain:
749
+ param.requires_grad = False
750
+
751
+ def unfreeze_decoder(self, unfreeze_embeddings=True):
752
+ for param in self.decoder.parameters():
753
+ param.requires_grad = True
754
+
755
+ for param in self.decoder_norm.parameters():
756
+ param.requires_grad = True
757
+
758
+ if unfreeze_embeddings:
759
+ for param in self.decoder_embeddings.parameters():
760
+ param.requires_grad = True
761
+
762
+ def freeze_shared_params(self):
763
+ self.freeze_encoder(freeze_embeddings=False)
764
+ self.freeze_decoder(freeze_embeddings=False)
765
+
766
+ def freeze_params_except_specific_embeddings(self, frozen_embedding_domain):
767
+ self.freeze_encoder_except_specific_embeddings(frozen_embedding_domain=frozen_embedding_domain)
768
+ self.freeze_decoder_except_specific_embeddings(frozen_embedding_domain=frozen_embedding_domain)
769
+
770
+ def unfreeze_shared_params(self):
771
+ self.unfreeze_encoder(unfreeze_embeddings=False)
772
+ self.unfreeze_decoder(unfreeze_embeddings=False)
773
+
774
+ def unfreeze_all(self):
775
+ self.unfreeze_encoder(unfreeze_embeddings=True)
776
+ self.unfreeze_decoder(unfreeze_embeddings=True)
777
+
778
+
779
+ ################################################
780
+
781
+ # Wrapper for easy loading with Huggingface Hub
782
+
783
+ class FM(FourM, PyTorchModelHubMixin):
784
+ """Wrapper around FourM for easy loading with Huggingface Hub.
785
+
786
+ Args:
787
+ config (dict): Dictionary containing the model and modality configuration,
788
+ used for loading from Huggingface Hub.
789
+ """
790
+ def __init__(self, config: dict):
791
+
792
+ config = copy.deepcopy(config)
793
+
794
+ all_domains = sorted(list(set(config['domains_in']) | set(config['domains_out'])))
795
+ modality_info = {mod: MODALITY_INFO[mod] for mod in all_domains}
796
+
797
+ encoder_embeddings = {}
798
+ for mod in config['domains_in']:
799
+ info = modality_info[mod]
800
+ if info.get("encoder_embedding", None) is not None:
801
+ if info["type"] == "img":
802
+ image_size, patch_size = info.get('input_size', config['image_size']), info.get('patch_size', config['patch_size'])
803
+ encoder_embeddings[mod] = info["encoder_embedding"](patch_size=patch_size, image_size=image_size)
804
+ else:
805
+ encoder_embeddings[mod] = info["encoder_embedding"]()
806
+
807
+ decoder_embeddings = {}
808
+ for mod in config['domains_out']:
809
+ info = modality_info[mod]
810
+ if info.get("decoder_embedding", None) is not None:
811
+ if info["type"] == "img":
812
+ image_size, patch_size = info.get('input_size', config['image_size']), info.get('patch_size', config['patch_size'])
813
+ decoder_embeddings[mod] = info["decoder_embedding"](patch_size=patch_size, image_size=image_size, share_embedding=False)
814
+ else:
815
+ decoder_embeddings[mod] = info["decoder_embedding"](share_embedding=False)
816
+
817
+ config['norm_layer'] = partial(LayerNorm, eps=1e-6, bias=config['norm_bias'])
818
+ config['act_layer'] = getattr(torch.nn, config['act_layer'])
819
+
820
+ del config['norm_bias']
821
+ del config['domains_in']
822
+ del config['domains_out']
823
+ del config['image_size']
824
+ del config['patch_size']
825
+
826
+ super().__init__(
827
+ encoder_embeddings=encoder_embeddings,
828
+ decoder_embeddings=decoder_embeddings,
829
+ modality_info=modality_info,
830
+ **config
831
+ )
832
+
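Because FM mixes in PyTorchModelHubMixin, checkpoints pushed to the Hugging Face Hub can be loaded via from_pretrained; a hedged sketch (the repository id below is a placeholder, not taken from this file):

from fourm.models.fm import FM

# Placeholder hub repository id - substitute an actual 4M checkpoint.
model = FM.from_pretrained('<org>/<4m-checkpoint>')
model.eval()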
833
+
834
+ ################################################
835
+
836
+ # Model definitions
837
+
838
+ # GELU variants
839
+ @register_model
840
+ def fm_tiny_6e_6d_gelu(
841
+ encoder_embeddings: Dict[str, nn.Module],
842
+ decoder_embeddings: Dict[str, nn.Module],
843
+ **kwargs):
844
+ model = FourM(
845
+ encoder_embeddings=encoder_embeddings,
846
+ decoder_embeddings=decoder_embeddings,
847
+ encoder_depth=6,
848
+ decoder_depth=6,
849
+ dim=384,
850
+ num_heads=6,
851
+ mlp_ratio=4,
852
+ qkv_bias=True,
853
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
854
+ **kwargs
855
+ )
856
+ return model
857
+
858
+
859
+ @register_model
860
+ def fm_small_8e_8d_gelu(
861
+ encoder_embeddings: Dict[str, nn.Module],
862
+ decoder_embeddings: Dict[str, nn.Module],
863
+ **kwargs):
864
+ model = FourM(
865
+ encoder_embeddings=encoder_embeddings,
866
+ decoder_embeddings=decoder_embeddings,
867
+ encoder_depth=8,
868
+ decoder_depth=8,
869
+ dim=512,
870
+ num_heads=8,
871
+ mlp_ratio=4,
872
+ qkv_bias=True,
873
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
874
+ **kwargs
875
+ )
876
+ return model
877
+
878
+
879
+ @register_model
880
+ def fm_base_12e_12d_gelu(
881
+ encoder_embeddings: Dict[str, nn.Module],
882
+ decoder_embeddings: Dict[str, nn.Module],
883
+ **kwargs):
884
+ model = FourM(
885
+ encoder_embeddings=encoder_embeddings,
886
+ decoder_embeddings=decoder_embeddings,
887
+ encoder_depth=12,
888
+ decoder_depth=12,
889
+ dim=768,
890
+ num_heads=12,
891
+ mlp_ratio=4,
892
+ qkv_bias=True,
893
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
894
+ **kwargs
895
+ )
896
+ return model
897
+
898
+
899
+ @register_model
900
+ def fm_large_24e_24d_gelu(
901
+ encoder_embeddings: Dict[str, nn.Module],
902
+ decoder_embeddings: Dict[str, nn.Module],
903
+ **kwargs):
904
+ model = FourM(
905
+ encoder_embeddings=encoder_embeddings,
906
+ decoder_embeddings=decoder_embeddings,
907
+ encoder_depth=24,
908
+ decoder_depth=24,
909
+ dim=1024,
910
+ num_heads=16,
911
+ mlp_ratio=4,
912
+ qkv_bias=True,
913
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
914
+ **kwargs
915
+ )
916
+ return model
917
+
918
+ @register_model
919
+ def fm_xlarge_24e_24d_gelu(
920
+ encoder_embeddings: Dict[str, nn.Module],
921
+ decoder_embeddings: Dict[str, nn.Module],
922
+ **kwargs):
923
+ model = FourM(
924
+ encoder_embeddings=encoder_embeddings,
925
+ decoder_embeddings=decoder_embeddings,
926
+ encoder_depth=24,
927
+ decoder_depth=24,
928
+ dim=2048,
929
+ num_heads=32,
930
+ mlp_ratio=4,
931
+ qkv_bias=True,
932
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
933
+ **kwargs
934
+ )
935
+ return model
936
+
937
+
938
+ # SwiGLU variants
939
+ @register_model
940
+ def fm_tiny_6e_6d_swiglu_nobias(
941
+ encoder_embeddings: Dict[str, nn.Module],
942
+ decoder_embeddings: Dict[str, nn.Module],
943
+ **kwargs):
944
+ model = FourM(
945
+ encoder_embeddings=encoder_embeddings,
946
+ decoder_embeddings=decoder_embeddings,
947
+ encoder_depth=6,
948
+ decoder_depth=6,
949
+ dim=384,
950
+ num_heads=6,
951
+ mlp_ratio=4,
952
+ qkv_bias=False,
953
+ proj_bias=False,
954
+ mlp_bias=False,
955
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
956
+ act_layer=nn.SiLU,
957
+ gated_mlp=True,
958
+ **kwargs
959
+ )
960
+ return model
961
+
962
+
963
+ @register_model
964
+ def fm_small_8e_8d_swiglu_nobias(
965
+ encoder_embeddings: Dict[str, nn.Module],
966
+ decoder_embeddings: Dict[str, nn.Module],
967
+ **kwargs):
968
+ model = FourM(
969
+ encoder_embeddings=encoder_embeddings,
970
+ decoder_embeddings=decoder_embeddings,
971
+ encoder_depth=8,
972
+ decoder_depth=8,
973
+ dim=512,
974
+ num_heads=8,
975
+ mlp_ratio=4,
976
+ qkv_bias=False,
977
+ proj_bias=False,
978
+ mlp_bias=False,
979
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
980
+ act_layer=nn.SiLU,
981
+ gated_mlp=True,
982
+ **kwargs
983
+ )
984
+ return model
985
+
986
+
987
+ @register_model
988
+ def fm_base_12e_12d_swiglu_nobias(
989
+ encoder_embeddings: Dict[str, nn.Module],
990
+ decoder_embeddings: Dict[str, nn.Module],
991
+ **kwargs):
992
+ model = FourM(
993
+ encoder_embeddings=encoder_embeddings,
994
+ decoder_embeddings=decoder_embeddings,
995
+ encoder_depth=12,
996
+ decoder_depth=12,
997
+ dim=768,
998
+ num_heads=12,
999
+ mlp_ratio=4,
1000
+ qkv_bias=False,
1001
+ proj_bias=False,
1002
+ mlp_bias=False,
1003
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
1004
+ act_layer=nn.SiLU,
1005
+ gated_mlp=True,
1006
+ **kwargs
1007
+ )
1008
+ return model
1009
+
1010
+ @register_model
1011
+ def fm_large_24e_24d_swiglu_nobias(
1012
+ encoder_embeddings: Dict[str, nn.Module],
1013
+ decoder_embeddings: Dict[str, nn.Module],
1014
+ **kwargs):
1015
+ model = FourM(
1016
+ encoder_embeddings=encoder_embeddings,
1017
+ decoder_embeddings=decoder_embeddings,
1018
+ encoder_depth=24,
1019
+ decoder_depth=24,
1020
+ dim=1024,
1021
+ num_heads=16,
1022
+ mlp_ratio=4,
1023
+ qkv_bias=False,
1024
+ proj_bias=False,
1025
+ mlp_bias=False,
1026
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
1027
+ act_layer=nn.SiLU,
1028
+ gated_mlp=True,
1029
+ **kwargs
1030
+ )
1031
+ return model
1032
+
1033
+ @register_model
1034
+ def fm_xlarge_24e_24d_swiglu_nobias(
1035
+ encoder_embeddings: Dict[str, nn.Module],
1036
+ decoder_embeddings: Dict[str, nn.Module],
1037
+ **kwargs):
1038
+ model = FourM(
1039
+ encoder_embeddings=encoder_embeddings,
1040
+ decoder_embeddings=decoder_embeddings,
1041
+ encoder_depth=24,
1042
+ decoder_depth=24,
1043
+ dim=2048,
1044
+ num_heads=32,
1045
+ mlp_ratio=4,
1046
+ qkv_bias=False,
1047
+ proj_bias=False,
1048
+ mlp_bias=False,
1049
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
1050
+ act_layer=nn.SiLU,
1051
+ gated_mlp=True,
1052
+ **kwargs
1053
+ )
1054
+ return model
1055
+
1056
+ # SwiGLU + QKNorm variants
1057
+
1058
+
1059
+ @register_model
1060
+ def fm_base_12e_12d_swiglu_qknorm_nobias(
1061
+ encoder_embeddings: Dict[str, nn.Module],
1062
+ decoder_embeddings: Dict[str, nn.Module],
1063
+ **kwargs):
1064
+ model = FourM(
1065
+ encoder_embeddings=encoder_embeddings,
1066
+ decoder_embeddings=decoder_embeddings,
1067
+ encoder_depth=12,
1068
+ decoder_depth=12,
1069
+ dim=768,
1070
+ num_heads=12,
1071
+ mlp_ratio=4,
1072
+ qkv_bias=False,
1073
+ proj_bias=False,
1074
+ mlp_bias=False,
1075
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
1076
+ act_layer=nn.SiLU,
1077
+ gated_mlp=True,
1078
+ qk_norm=True,
1079
+ **kwargs
1080
+ )
1081
+ return model
1082
+
1083
+
1084
+ @register_model
1085
+ def fm_large_24e_24d_swiglu_qknorm_nobias(
1086
+ encoder_embeddings: Dict[str, nn.Module],
1087
+ decoder_embeddings: Dict[str, nn.Module],
1088
+ **kwargs):
1089
+ model = FourM(
1090
+ encoder_embeddings=encoder_embeddings,
1091
+ decoder_embeddings=decoder_embeddings,
1092
+ encoder_depth=24,
1093
+ decoder_depth=24,
1094
+ dim=1024,
1095
+ num_heads=16,
1096
+ mlp_ratio=4,
1097
+ qkv_bias=False,
1098
+ proj_bias=False,
1099
+ mlp_bias=False,
1100
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
1101
+ act_layer=nn.SiLU,
1102
+ gated_mlp=True,
1103
+ qk_norm=True,
1104
+ **kwargs
1105
+ )
1106
+ return model
1107
+
1108
+ @register_model
1109
+ def fm_xlarge_24e_24d_swiglu_qknorm_nobias(
1110
+ encoder_embeddings: Dict[str, nn.Module],
1111
+ decoder_embeddings: Dict[str, nn.Module],
1112
+ **kwargs):
1113
+ model = FourM(
1114
+ encoder_embeddings=encoder_embeddings,
1115
+ decoder_embeddings=decoder_embeddings,
1116
+ encoder_depth=24,
1117
+ decoder_depth=24,
1118
+ dim=2048,
1119
+ num_heads=32,
1120
+ mlp_ratio=4,
1121
+ qkv_bias=False,
1122
+ proj_bias=False,
1123
+ mlp_bias=False,
1124
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
1125
+ act_layer=nn.SiLU,
1126
+ gated_mlp=True,
1127
+ qk_norm=True,
1128
+ **kwargs
1129
+ )
1130
+ return model
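A minimal loading sketch for the wrapper and model definitions above. Everything below is an assumption for illustration: the wrapper class is taken to be exported as fourm.models.fm.FM and to mix in PyTorchModelHubMixin (as its ViT counterpart FMViT does further down in this commit), and the checkpoint id is hypothetical.

from fourm.models.fm import FM  # assumed export name of the hub wrapper above

fm = FM.from_pretrained("EPFL-VILAB/4M-7_B_CC12M")  # hypothetical checkpoint id
fm.eval()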
fourm/models/fm_utils.py ADDED
@@ -0,0 +1,387 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------
15
+ # Some functions are based on the timm code base
16
+ # https://github.com/huggingface/pytorch-image-models
17
+ # --------------------------------------------------------
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from einops import rearrange
23
+
24
+
25
+ def pair(t):
26
+ return t if isinstance(t, tuple) else (t, t)
27
+
28
+ def softmax1(tensor):
29
+ # See https://www.evanmiller.org/attention-is-off-by-one.html
30
+ return F.pad(tensor, (0,1)).softmax(dim=-1)[...,:-1]
31
+
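A tiny illustration (not part of the file, reusing its torch/F imports) of what the extra logit in softmax1 changes: with uniformly very negative scores, a head can put essentially no attention anywhere, whereas the standard softmax is forced to spread mass uniformly.

scores = torch.full((1, 4), -1e4)
F.softmax(scores, dim=-1)  # tensor([[0.25, 0.25, 0.25, 0.25]])
softmax1(scores)           # ~tensor([[0., 0., 0., 0.]]), mass flows to the implicit extra logit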
32
+ def build_1d_sincos_posemb(max_len, embed_dim=1024, temperature=10000.):
33
+ """Sine-cosine positional embeddings from MoCo-v3, adapted back to 1d
34
+
35
+ Returns positional embedding of shape (1, N, D)
36
+ """
37
+ arange = torch.arange(max_len, dtype=torch.float32) # Shape (N,)
38
+ assert embed_dim % 2 == 0, 'Embed dimension must be divisible by 2 for 1D sin-cos position embedding'
39
+ pos_dim = embed_dim // 2
40
+ omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim # Shape (D/2,)
41
+ omega = 1. / (temperature ** omega)
42
+ out = torch.einsum('n,d->nd', [arange, omega]) # Outer product, shape (N, D/2)
43
+ pos_emb = torch.cat([torch.sin(out), torch.cos(out)], dim=1).unsqueeze(0) # Shape (1, N, D)
44
+ return pos_emb
45
+
46
+ def build_2d_sincos_posemb(h, w, embed_dim=1024, temperature=10000.0):
47
+ """Sine-cosine positional embeddings as used in MoCo-v3
48
+
49
+ Returns positional embedding of shape (1, N, D) where N = W*H
50
+ """
51
+ grid_w = torch.arange(w, dtype=torch.float32) # Shape (W,)
52
+ grid_h = torch.arange(h, dtype=torch.float32) # Shape (H,)
53
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij') # Shapes (W, H)
54
+ assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
55
+ pos_dim = embed_dim // 4
56
+ omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim # Shape (D/4,)
57
+ omega = 1. / (temperature ** omega)
58
+ out_w = torch.einsum('n,d->nd', [grid_w.reshape(-1), omega]) # Outer product, shape (W*H, D/4)
59
+ out_h = torch.einsum('n,d->nd', [grid_h.reshape(-1), omega]) # Outer product, shape (W*H, D/4)
60
+ pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1).unsqueeze(0) # Shape (1, W*H, D)
61
+ return pos_emb
62
+
63
+
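A quick shape check for the two positional-embedding helpers above; they are deterministic and parameter-free, and a 14x14 grid corresponds to a 224px image with 16px patches.

pe_1d = build_1d_sincos_posemb(max_len=16, embed_dim=64)   # shape (1, 16, 64)
pe_2d = build_2d_sincos_posemb(h=14, w=14, embed_dim=768)  # shape (1, 196, 768)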
64
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
65
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
66
+ Implementation from timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
67
+ """
68
+ if drop_prob == 0. or not training:
69
+ return x
70
+ keep_prob = 1 - drop_prob
71
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
72
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
73
+ random_tensor.floor_() # binarize
74
+ output = x.div(keep_prob) * random_tensor
75
+ return output
76
+
77
+
78
+ class DropPath(nn.Module):
79
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
80
+ """
81
+
82
+ def __init__(self, drop_prob=None):
83
+ super(DropPath, self).__init__()
84
+ self.drop_prob = drop_prob
85
+
86
+ def forward(self, x):
87
+ return drop_path(x, self.drop_prob, self.training)
88
+
89
+ def extra_repr(self) -> str:
90
+ return 'p={}'.format(self.drop_prob)
91
+
92
+
93
+ class LayerNorm(nn.Module):
94
+ """Custom implementation of LayerNorm with the option to disable the bias term"""
95
+ def __init__(self, normalized_shape: int, eps=1e-5, bias=True):
96
+ super().__init__()
97
+ self.eps = eps
98
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
99
+ if bias:
100
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
101
+ else:
102
+ self.register_buffer("bias", torch.zeros(normalized_shape))
103
+
104
+ # Normalized shape must be a tuple for F.layer_norm
105
+ self.normalized_shape = (normalized_shape,)
106
+
107
+ def forward(self, x):
108
+ return nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, eps=self.eps)
109
+
110
+
111
+ class Mlp(nn.Module):
112
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., bias=True):
113
+ super().__init__()
114
+ out_features = out_features or in_features
115
+ hidden_features = hidden_features or in_features
116
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
117
+ self.act = act_layer()
118
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
119
+ self.drop = nn.Dropout(drop)
120
+
121
+ def forward(self, x):
122
+ x = self.fc1(x)
123
+ x = self.act(x)
124
+ x = self.fc2(x)
125
+ x = self.drop(x)
126
+ return x
127
+
128
+
129
+ class GatedMlp(nn.Module):
130
+ """Implements SwiGLU and other gated feed-forward layers from Noam Shazeer's paper: https://arxiv.org/abs/2002.05202
131
+ """
132
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, bias=True):
133
+ super().__init__()
134
+ out_features = out_features or in_features
135
+ # If gated, multiply hidden_dim by 2/3 to account for extra matmul
136
+ hidden_features = int(2 * (hidden_features or in_features) / 3)
137
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
138
+ self.act = act_layer()
139
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
140
+ self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
141
+
142
+ def forward(self, x):
143
+ x = self.fc2(self.act(self.fc1(x)) * self.fc3(x))
144
+ return x
145
+
146
+
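The 2/3 rescaling above keeps the gated MLP's weight count close to that of the plain Mlp for the same nominal hidden size: requesting 4*768 = 3072 hidden features gives an actual width of int(2*3072/3) = 2048 across three projections, so 3*768*2048 = 2*768*3072 weights. A short sketch, reusing this file's imports:

plain = Mlp(in_features=768, hidden_features=3072)        # fc1/fc2 at width 3072
gated = GatedMlp(in_features=768, hidden_features=3072)   # fc1/fc2/fc3 at width 2048
out = gated(torch.randn(2, 196, 768))                     # shape preserved: (2, 196, 768)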
147
+ class Attention(nn.Module):
148
+ def __init__(self, dim, num_heads=8, qkv_bias=False, proj_bias=True, attn_drop=0., proj_drop=0., allow_zero_attn=False):
149
+ super().__init__()
150
+ self.num_heads = num_heads
151
+ head_dim = dim // num_heads
152
+ self.scale = head_dim ** -0.5
153
+ self.allow_zero_attn = allow_zero_attn
154
+
155
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
156
+ self.attn_drop = nn.Dropout(attn_drop)
157
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
158
+ self.proj_drop = nn.Dropout(proj_drop)
159
+
160
+ def forward(self, x, mask=None):
161
+ B, N, C = x.shape
162
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
163
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
164
+
165
+ attn = (q @ k.transpose(-2, -1)) * self.scale
166
+
167
+ if mask is not None:
168
+ mask = mask.unsqueeze(1) # Unsqueeze attention mask for multi-head
169
+ attn = attn.masked_fill(mask, -torch.finfo(attn.dtype).max)
170
+
171
+ if self.allow_zero_attn:
172
+ attn = softmax1(attn)
173
+ else:
174
+ attn = attn.softmax(dim=-1)
175
+ attn = self.attn_drop(attn)
176
+
177
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
178
+ x = self.proj(x)
179
+ x = self.proj_drop(x)
180
+ return x
181
+
182
+ class CrossAttention(nn.Module):
183
+ def __init__(self, dim, num_heads=8, qkv_bias=False, proj_bias=True, attn_drop=0., proj_drop=0., allow_zero_attn=False):
184
+ super().__init__()
185
+ self.num_heads = num_heads
186
+ head_dim = dim // num_heads
187
+ self.scale = head_dim ** -0.5
188
+ self.allow_zero_attn = allow_zero_attn
189
+
190
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
191
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
192
+
193
+ self.attn_drop = nn.Dropout(attn_drop)
194
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
195
+ self.proj_drop = nn.Dropout(proj_drop)
196
+
197
+ def forward(self, x, context, mask=None):
198
+ B, N, C = x.shape
199
+ _, M, _ = context.shape
200
+
201
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
202
+ kv = self.kv(context).reshape(B, M, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
203
+ k, v = kv[0], kv[1]
204
+
205
+ attn = (q @ k.transpose(-2, -1)) * self.scale
206
+ if mask is not None:
207
+ mask = rearrange(mask, "b n m -> b 1 n m") # Unsqueeze / reshape for multi-head
208
+ attn = attn.masked_fill(mask, -torch.finfo(attn.dtype).max)
209
+
210
+ if self.allow_zero_attn:
211
+ attn = softmax1(attn)
212
+ else:
213
+ attn = attn.softmax(dim=-1)
214
+ attn = self.attn_drop(attn)
215
+
216
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
217
+ x = self.proj(x)
218
+ x = self.proj_drop(x)
219
+ return x
220
+
221
+
222
+ class NormAttention(nn.Module):
223
+ def __init__(self, dim, num_heads=8, qkv_bias=False, proj_bias=True, norm_layer=nn.LayerNorm, attn_drop=0., proj_drop=0., allow_zero_attn=False):
224
+ super().__init__()
225
+ self.num_heads = num_heads
226
+ head_dim = dim // num_heads
227
+ self.scale = head_dim ** -0.5
228
+ self.allow_zero_attn = allow_zero_attn
229
+
230
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
231
+ self.attn_drop = nn.Dropout(attn_drop)
232
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
233
+ self.proj_drop = nn.Dropout(proj_drop)
234
+
235
+ self.q_norm = norm_layer(head_dim)
236
+ self.k_norm = norm_layer(head_dim)
237
+
238
+ def forward(self, x, mask=None):
239
+ B, N, C = x.shape
240
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
241
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
242
+
243
+ q = self.q_norm(q)
244
+ k = self.k_norm(k)
245
+
246
+ attn = (q @ k.transpose(-2, -1)) * self.scale
247
+
248
+ if mask is not None:
249
+ mask = mask.unsqueeze(1) # Unsqueeze for multi-head
250
+ attn = attn.masked_fill(mask, -torch.finfo(attn.dtype).max)
251
+
252
+ if self.allow_zero_attn:
253
+ attn = softmax1(attn)
254
+ else:
255
+ attn = attn.softmax(dim=-1)
256
+ attn = self.attn_drop(attn)
257
+
258
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
259
+ x = self.proj(x)
260
+ x = self.proj_drop(x)
261
+ return x
262
+
263
+
264
+ class NormCrossAttention(nn.Module):
265
+ def __init__(self, dim, num_heads=8, qkv_bias=False, proj_bias=True, norm_layer=nn.LayerNorm, attn_drop=0., proj_drop=0., allow_zero_attn=False):
266
+ super().__init__()
267
+ self.num_heads = num_heads
268
+ head_dim = dim // num_heads
269
+ self.scale = head_dim ** -0.5
270
+ self.allow_zero_attn = allow_zero_attn
271
+
272
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
273
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
274
+
275
+ self.attn_drop = nn.Dropout(attn_drop)
276
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
277
+ self.proj_drop = nn.Dropout(proj_drop)
278
+
279
+ self.q_norm = norm_layer(head_dim)
280
+ self.k_norm = norm_layer(head_dim)
281
+
282
+ def forward(self, x, context, mask=None):
283
+ B, N, C = x.shape
284
+ _, M, _ = context.shape
285
+
286
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
287
+ kv = self.kv(context).reshape(B, M, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
288
+ k, v = kv[0], kv[1]
289
+
290
+ q = self.q_norm(q)
291
+ k = self.k_norm(k)
292
+
293
+ attn = (q @ k.transpose(-2, -1)) * self.scale
294
+ if mask is not None:
295
+ mask = rearrange(mask, "b n m -> b 1 n m") # Unsqueeze / reshape for multi-head
296
+ attn = attn.masked_fill(mask, -torch.finfo(attn.dtype).max)
297
+
298
+ if self.allow_zero_attn:
299
+ attn = softmax1(attn)
300
+ else:
301
+ attn = attn.softmax(dim=-1)
302
+ attn = self.attn_drop(attn)
303
+
304
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
305
+ x = self.proj(x)
306
+ x = self.proj_drop(x)
307
+ return x
308
+
309
+
310
+ class Block(nn.Module):
311
+
312
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, proj_bias=True, mlp_bias=True, drop=0., attn_drop=0.,
313
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, gated_mlp=False, qk_norm=False, allow_zero_attn=False):
314
+ super().__init__()
315
+ self.norm1 = norm_layer(dim)
316
+
317
+ if not qk_norm:
318
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, allow_zero_attn=allow_zero_attn)
319
+ else:
320
+ self.attn = NormAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, norm_layer=norm_layer, attn_drop=attn_drop, proj_drop=drop, allow_zero_attn=allow_zero_attn)
321
+
322
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
323
+ self.norm2 = norm_layer(dim)
324
+ mlp_hidden_dim = int(dim * mlp_ratio)
325
+
326
+ if not gated_mlp:
327
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, bias=mlp_bias, drop=drop)
328
+ else:
329
+ self.mlp = GatedMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, bias=mlp_bias)
330
+
331
+ def forward(self, x, mask=None):
332
+ x = x + self.drop_path(self.attn(self.norm1(x), mask))
333
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
334
+ return x
335
+
336
+
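A minimal sketch of stacking this Block in the SwiGLU/no-bias configuration used by the *_swiglu_nobias model definitions elsewhere in this commit; it assumes this file's imports plus functools.partial.

from functools import partial

blocks = nn.ModuleList([
    Block(dim=384, num_heads=6, mlp_ratio=4., qkv_bias=False, proj_bias=False, mlp_bias=False,
          norm_layer=partial(LayerNorm, eps=1e-6, bias=False), act_layer=nn.SiLU, gated_mlp=True)
    for _ in range(2)
])
x = torch.randn(2, 196, 384)
for blk in blocks:
    x = blk(x)  # shape stays (2, 196, 384)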
337
+ class DecoderBlock(nn.Module):
338
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, proj_bias=True, mlp_bias=True, drop=0., attn_drop=0.,
339
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, gated_mlp=False, qk_norm=False, allow_zero_attn=False):
340
+ super().__init__()
341
+ self.norm1 = norm_layer(dim)
342
+
343
+ if not qk_norm:
344
+ self.self_attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, allow_zero_attn=allow_zero_attn)
345
+ self.cross_attn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, allow_zero_attn=allow_zero_attn)
346
+ else:
347
+ self.self_attn = NormAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, norm_layer=norm_layer, attn_drop=attn_drop, proj_drop=drop, allow_zero_attn=allow_zero_attn)
348
+ self.cross_attn = NormCrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, norm_layer=norm_layer, attn_drop=attn_drop, proj_drop=drop, allow_zero_attn=allow_zero_attn)
349
+
350
+
351
+ self.query_norm = norm_layer(dim)
352
+ self.context_norm = norm_layer(dim)
353
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
354
+ self.norm2 = norm_layer(dim)
355
+ mlp_hidden_dim = int(dim * mlp_ratio)
356
+
357
+ if not gated_mlp:
358
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, bias=mlp_bias, drop=drop)
359
+ else:
360
+ self.mlp = GatedMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, bias=mlp_bias)
361
+
362
+ def forward(self, x, context, sa_mask=None, xa_mask=None):
363
+ x = x + self.drop_path(self.self_attn(self.norm1(x), sa_mask))
364
+ x = x + self.drop_path(self.cross_attn(self.query_norm(x), self.context_norm(context), xa_mask))
365
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
366
+ return x
367
+
368
+
369
+ class CrossAttentionBlock(nn.Module):
370
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
371
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, gated_mlp=False, allow_zero_attn=False):
372
+ super().__init__()
373
+ self.cross_attn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, allow_zero_attn=allow_zero_attn)
374
+ self.query_norm = norm_layer(dim)
375
+ self.context_norm = norm_layer(dim)
376
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
377
+ self.norm2 = norm_layer(dim)
378
+ mlp_hidden_dim = int(dim * mlp_ratio)
379
+ if not gated_mlp:
380
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
381
+ else:
382
+ self.mlp = GatedMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
383
+
384
+ def forward(self, x, context, xa_mask=None, **kwargs):
385
+ x = x + self.drop_path(self.cross_attn(self.query_norm(x), self.context_norm(context), xa_mask))
386
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
387
+ return x
fourm/models/fm_vit.py ADDED
@@ -0,0 +1,485 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ import copy
16
+ from functools import partial
17
+ from typing import Optional, Union
18
+
19
+ import torch
20
+ from torch import nn
21
+
22
+ from fourm.utils.timm.registry import register_model
23
+ from huggingface_hub import PyTorchModelHubMixin
24
+
25
+ from .encoder_embeddings import ImageEncoderEmbedding
26
+ from .fm_utils import Block, LayerNorm
27
+ from fourm.data.modality_info import MODALITY_INFO
28
+
29
+
30
+ __all__ = [
31
+ # GELU models
32
+ 'fm_vit_tiny_6e_gelu',
33
+ 'fm_vit_small_8e_gelu',
34
+ 'fm_vit_base_12e_gelu',
35
+ 'fm_vit_large_24e_gelu',
36
+ 'fm_vit_xlarge_24e_gelu',
37
+ # SwiGLU models
38
+ 'fm_vit_tiny_6e_swiglu_nobias',
39
+ 'fm_vit_small_8e_swiglu_nobias',
40
+ 'fm_vit_base_12e_swiglu_nobias',
41
+ 'fm_vit_large_24e_swiglu_nobias',
42
+ 'fm_vit_xlarge_24e_swiglu_nobias',
43
+ # SwiGLU + QKNorm models
44
+ 'fm_vit_base_12e_swiglu_qknorm_nobias',
45
+ 'fm_vit_large_24e_swiglu_qknorm_nobias',
46
+ 'fm_vit_xlarge_24e_swiglu_qknorm_nobias',
47
+ ]
48
+
49
+ class FourMViT(nn.Module):
50
+ """Modified 4M model, adapted to behave as a simple RGB-only ViT.
51
+
52
+ Args:
53
+ img_size (int): Input image size.
54
+ patch_size (int): Patch size.
55
+ in_chans (int): Number of input image channels.
56
+ dim (int): Patch embedding dimension.
57
+ encoder_depth (int): Depth of ViT / number of encoder blocks.
58
+ num_heads (int): Number of attention heads in each ViT block.
59
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
60
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
61
+ proj_bias (bool): If True, adds a bias to the attention out proj layer.
62
+ mlp_bias (bool): If True, adds a learnable bias for the feedforward.
63
+ drop_path_rate (float): Stochastic depth rate.
64
+ drop_rate (float): Dropout rate.
65
+ attn_drop_rate (float): Attention dropout rate.
66
+ act_layer (nn.Module): Activation layer.
67
+ norm_layer (nn.Module): Normalization layer.
68
+ gated_mlp (bool): If True, makes the feedforward gated (e.g., for SwiGLU)
69
+ qk_norm (bool): If True, normalizes the query and keys (as in ViT-22B)
70
+ use_act_checkpoint (bool): If True, use activation checkpointing.
71
+ encoder_norm (bool): If True, adds a norm layer after the last encoder block.
72
+ output_head (Optional[nn.Module]): Optional output head after the encoder
73
+ """
74
+ def __init__(
75
+ self,
76
+ img_size=224,
77
+ patch_size=16,
78
+ in_chans=3,
79
+ dim=768,
80
+ encoder_depth=12,
81
+ num_heads=12,
82
+ mlp_ratio=4.0,
83
+ qkv_bias: bool = True,
84
+ proj_bias: bool = True,
85
+ mlp_bias: bool = True,
86
+ drop_path_rate: float = 0.0,
87
+ drop_rate: float = 0.0,
88
+ attn_drop_rate: float = 0.0,
89
+ act_layer: nn.Module = nn.GELU,
90
+ norm_layer: Union[partial, nn.Module] = partial(LayerNorm, eps=1e-6),
91
+ gated_mlp: bool = False, # Make the feedforward gated for e.g. SwiGLU
92
+ qk_norm: bool = False,
93
+ encoder_norm = True,
94
+ output_head: Optional[nn.Module] = None,
95
+ ):
96
+ super().__init__()
97
+ self.img_size = img_size
98
+ self.init_std = 0.02
99
+ rgb_embedding = ImageEncoderEmbedding(num_channels=in_chans, patch_size=patch_size,
100
+ dim_tokens=dim, sincos_pos_emb=True, image_size=img_size)
101
+ self.num_patches = rgb_embedding.num_patches
102
+ self.encoder_embeddings = nn.ModuleDict({f"rgb@{img_size}": rgb_embedding})
103
+
104
+ # stochastic depth decay rule
105
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_depth)]
106
+
107
+ self.encoder = nn.ModuleList([
108
+ Block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, mlp_bias=mlp_bias,
109
+ drop_path=dpr[i], drop=drop_rate, attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer,
110
+ gated_mlp=gated_mlp, qk_norm=qk_norm)
111
+ for i in range(encoder_depth)
112
+ ])
113
+
114
+ self.encoder_norm = norm_layer(dim) if encoder_norm else nn.Identity()
115
+
116
+ # Weight init
117
+ self.init_weights()
118
+
119
+ # Classification head is initialized after init_weights() to allow for special init scale
120
+ if output_head is not None:
121
+ self.output_head = output_head
122
+ if hasattr(self.output_head, 'init'):
123
+ self.output_head.init(dim)
124
+ else:
125
+ self.output_head = nn.Identity()
126
+
127
+ def init_weights(self):
128
+ """Weight initialization following MAE's initialization scheme"""
129
+
130
+ for name, m in self.named_modules():
131
+ # Skipping tokenizers to avoid reinitializing them
132
+ if "tokenizer" in name:
133
+ continue
134
+ # Linear
135
+ elif isinstance(m, nn.Linear):
136
+ if 'qkv' in name:
137
+ # treat the weights of Q, K, V separately
138
+ val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
139
+ nn.init.uniform_(m.weight, -val, val)
140
+ elif 'kv' in name:
141
+ # treat the weights of K, V separately
142
+ val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
143
+ nn.init.uniform_(m.weight, -val, val)
144
+ else:
145
+ nn.init.xavier_uniform_(m.weight)
146
+ if isinstance(m, nn.Linear) and m.bias is not None:
147
+ nn.init.constant_(m.bias, 0)
148
+ # LayerNorm
149
+ elif isinstance(m, nn.LayerNorm) or isinstance(m, LayerNorm):
150
+ nn.init.constant_(m.weight, 1.0)
151
+ nn.init.constant_(m.bias, 0)
152
+
153
+ # Embedding
154
+ elif isinstance(m, nn.Embedding):
155
+ nn.init.normal_(m.weight, std=self.init_std)
156
+ # Conv2d
157
+ elif isinstance(m, nn.Conv2d):
158
+ if '.proj' in name:
159
+ # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
160
+ w = m.weight.data
161
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
162
+
163
+ def get_num_layers_encoder(self):
164
+ return len(self.encoder)
165
+
166
+ def get_num_layers(self):
167
+ return self.get_num_layers_encoder()
168
+
169
+ @torch.jit.ignore
170
+ def no_weight_decay(self):
171
+ no_wd_set = set()
172
+
173
+ for mod, emb_module in self.encoder_embeddings.items():
174
+ if hasattr(emb_module, 'no_weight_decay'):
175
+ to_skip = emb_module.no_weight_decay()
176
+ to_skip = set([f'encoder_embeddings.{mod}.{name}' for name in to_skip])
177
+ no_wd_set = no_wd_set | to_skip
178
+
179
+ return no_wd_set
180
+
181
+
182
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
183
+ """
184
+ Forward pass of the model.
185
+
186
+ Args:
187
+ x (torch.Tensor): Input tensor. Shape (B, C, H, W)
188
+
189
+ Returns:
190
+ torch.Tensor: Output tensor. Shape (B, num_classes) if an output head is set, otherwise (B, N, D) patch features.
191
+ """
192
+ rgb_dict = {'tensor': x}
193
+ rgb_dict = self.encoder_embeddings[f'rgb@{self.img_size}'](rgb_dict)
194
+
195
+ # Add embeddings to patchified RGB image
196
+ x = rgb_dict['x'] + rgb_dict['emb'] # Shape: (B, N, D) with N = num_patches
197
+
198
+ for blk in self.encoder:
199
+ x = blk(x)
200
+
201
+ x = self.encoder_norm(x) # Shape: (B, N, D)
202
+
203
+ out = self.output_head(x)
204
+
205
+ return out
206
+
207
+
208
+ def freeze_encoder(self, freeze_embeddings=True):
209
+ for param in self.encoder.parameters():
210
+ param.requires_grad = False
211
+
212
+ for param in self.encoder_norm.parameters():
213
+ param.requires_grad = False
214
+
215
+ if freeze_embeddings:
216
+ for param in self.encoder_embeddings.parameters():
217
+ param.requires_grad = False
218
+
219
+ def unfreeze_encoder(self, unfreeze_embeddings=True):
220
+ for param in self.encoder.parameters():
221
+ param.requires_grad = True
222
+
223
+ for param in self.encoder_norm.parameters():
224
+ param.requires_grad = True
225
+
226
+ if unfreeze_embeddings:
227
+ for param in self.encoder_embeddings.parameters():
228
+ param.requires_grad = True
229
+
230
+
231
+ ################################################
232
+
233
+ # Wrapper for easy loading with Huggingface Hub
234
+
235
+ class FMViT(FourMViT, PyTorchModelHubMixin):
236
+ """Wrapper around FourMViT for easy loading with Huggingface Hub.
237
+
238
+ Args:
239
+ config (dict): Dictionary containing the model and modality configuration,
240
+ used for loading from Huggingface Hub.
241
+ output_head (nn.Module): Optional output head.
242
+ """
243
+ def __init__(self, config: dict, output_head: Optional[nn.Module] = None):
244
+
245
+ config = copy.deepcopy(config)
246
+
247
+
248
+
249
+ config['norm_layer'] = partial(LayerNorm, eps=1e-6, bias=config['norm_bias'])
250
+ config['act_layer'] = getattr(torch.nn, config['act_layer'])
251
+
252
+ img_size = config['image_size']
253
+ config['img_size'] = img_size
254
+ config['patch_size'] = MODALITY_INFO[f'rgb@{img_size}'].get('patch_size', config['patch_size'])
255
+ config['in_chans'] = MODALITY_INFO[f'rgb@{img_size}'].get('num_channels', 3)
256
+
257
+ for key in ['image_size', 'norm_bias', 'domains_in', 'domains_out', 'decoder_depth', 'share_modality_embeddings']:
258
+ if key in config:
259
+ del config[key]
260
+
261
+ super().__init__(
262
+ output_head=output_head,
263
+ **config
264
+ )
265
+
266
+
267
+ ################################################
268
+
269
+ # Model definitions
270
+
271
+ # GELU variants
272
+ @register_model
273
+ def fm_vit_tiny_6e_gelu(**kwargs):
274
+ model = FourMViT(
275
+ encoder_depth=6,
276
+ dim=384,
277
+ num_heads=6,
278
+ mlp_ratio=4,
279
+ qkv_bias=True,
280
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
281
+ **kwargs
282
+ )
283
+ return model
284
+
285
+
286
+ @register_model
287
+ def fm_vit_small_8e_gelu(**kwargs):
288
+ model = FourMViT(
289
+ encoder_depth=8,
290
+ dim=512,
291
+ num_heads=8,
292
+ mlp_ratio=4,
293
+ qkv_bias=True,
294
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
295
+ **kwargs
296
+ )
297
+ return model
298
+
299
+
300
+ @register_model
301
+ def fm_vit_base_12e_gelu(**kwargs):
302
+ model = FourMViT(
303
+ encoder_depth=12,
304
+ dim=768,
305
+ num_heads=12,
306
+ mlp_ratio=4,
307
+ qkv_bias=True,
308
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
309
+ **kwargs
310
+ )
311
+ return model
312
+
313
+
314
+ @register_model
315
+ def fm_vit_large_24e_gelu(**kwargs):
316
+ model = FourMViT(
317
+ encoder_depth=24,
318
+ dim=1024,
319
+ num_heads=16,
320
+ mlp_ratio=4,
321
+ qkv_bias=True,
322
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
323
+ **kwargs
324
+ )
325
+ return model
326
+
327
+ @register_model
328
+ def fm_vit_xlarge_24e_gelu(**kwargs):
329
+ model = FourMViT(
330
+ encoder_depth=24,
331
+ dim=2048,
332
+ num_heads=32,
333
+ mlp_ratio=4,
334
+ qkv_bias=True,
335
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
336
+ **kwargs
337
+ )
338
+ return model
339
+
340
+
341
+ # SwiGLU variants
342
+ @register_model
343
+ def fm_vit_tiny_6e_swiglu_nobias(**kwargs):
344
+ model = FourMViT(
345
+ encoder_depth=6,
346
+ dim=384,
347
+ num_heads=6,
348
+ mlp_ratio=4,
349
+ qkv_bias=False,
350
+ proj_bias=False,
351
+ mlp_bias=False,
352
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
353
+ act_layer=nn.SiLU,
354
+ gated_mlp=True,
355
+ **kwargs
356
+ )
357
+ return model
358
+
359
+
360
+ @register_model
361
+ def fm_vit_small_8e_swiglu_nobias(**kwargs):
362
+ model = FourMViT(
363
+ encoder_depth=8,
364
+ dim=512,
365
+ num_heads=8,
366
+ mlp_ratio=4,
367
+ qkv_bias=False,
368
+ proj_bias=False,
369
+ mlp_bias=False,
370
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
371
+ act_layer=nn.SiLU,
372
+ gated_mlp=True,
373
+ **kwargs
374
+ )
375
+ return model
376
+
377
+
378
+ @register_model
379
+ def fm_vit_base_12e_swiglu_nobias(**kwargs):
380
+ model = FourMViT(
381
+ encoder_depth=12,
382
+ dim=768,
383
+ num_heads=12,
384
+ mlp_ratio=4,
385
+ qkv_bias=False,
386
+ proj_bias=False,
387
+ mlp_bias=False,
388
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
389
+ act_layer=nn.SiLU,
390
+ gated_mlp=True,
391
+ **kwargs
392
+ )
393
+ return model
394
+
395
+
396
+ @register_model
397
+ def fm_vit_large_24e_swiglu_nobias(**kwargs):
398
+ model = FourMViT(
399
+ encoder_depth=24,
400
+ dim=1024,
401
+ num_heads=16,
402
+ mlp_ratio=4,
403
+ qkv_bias=False,
404
+ proj_bias=False,
405
+ mlp_bias=False,
406
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
407
+ act_layer=nn.SiLU,
408
+ gated_mlp=True,
409
+ **kwargs
410
+ )
411
+ return model
412
+
413
+ @register_model
414
+ def fm_vit_xlarge_24e_swiglu_nobias(**kwargs):
415
+ model = FourMViT(
416
+ encoder_depth=24,
417
+ dim=2048,
418
+ num_heads=32,
419
+ mlp_ratio=4,
420
+ qkv_bias=False,
421
+ proj_bias=False,
422
+ mlp_bias=False,
423
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
424
+ act_layer=nn.SiLU,
425
+ gated_mlp=True,
426
+ **kwargs
427
+ )
428
+ return model
429
+
430
+ # SwiGLU + QKNorm variants
431
+
432
+ @register_model
433
+ def fm_vit_base_12e_swiglu_qknorm_nobias(**kwargs):
434
+ model = FourMViT(
435
+ encoder_depth=12,
436
+ dim=768,
437
+ num_heads=12,
438
+ mlp_ratio=4,
439
+ qkv_bias=False,
440
+ proj_bias=False,
441
+ mlp_bias=False,
442
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
443
+ act_layer=nn.SiLU,
444
+ gated_mlp=True,
445
+ qk_norm=True,
446
+ **kwargs
447
+ )
448
+ return model
449
+
450
+
451
+ @register_model
452
+ def fm_vit_large_24e_swiglu_qknorm_nobias(**kwargs):
453
+ model = FourMViT(
454
+ encoder_depth=24,
455
+ dim=1024,
456
+ num_heads=16,
457
+ mlp_ratio=4,
458
+ qkv_bias=False,
459
+ proj_bias=False,
460
+ mlp_bias=False,
461
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
462
+ act_layer=nn.SiLU,
463
+ gated_mlp=True,
464
+ qk_norm=True,
465
+ **kwargs
466
+ )
467
+ return model
468
+
469
+ @register_model
470
+ def fm_vit_xlarge_24e_swiglu_qknorm_nobias(**kwargs):
471
+ model = FourMViT(
472
+ encoder_depth=24,
473
+ dim=2048,
474
+ num_heads=32,
475
+ mlp_ratio=4,
476
+ qkv_bias=False,
477
+ proj_bias=False,
478
+ mlp_bias=False,
479
+ norm_layer=partial(LayerNorm, eps=1e-6, bias=False),
480
+ act_layer=nn.SiLU,
481
+ gated_mlp=True,
482
+ qk_norm=True,
483
+ **kwargs
484
+ )
485
+ return model
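A minimal usage sketch for the registered RGB-only variants above; with the default Identity output head the model returns per-patch features rather than class logits. It assumes this file's imports and a 224px input with 16px patches.

vit = fm_vit_base_12e_gelu(img_size=224, patch_size=16)
feats = vit(torch.randn(1, 3, 224, 224))  # expected shape (1, 196, 768): 14x14 patches at dim 768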
fourm/models/generate.py ADDED
@@ -0,0 +1,1273 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from collections import defaultdict
15
+ from typing import Union, List, Optional
16
+
17
+ import numpy as np
18
+ import torch
19
+ from einops import rearrange, repeat
20
+ from torch import nn
21
+ import torch.nn.functional as F
22
+
23
+ from fourm.utils import get_sentinel_to_id_mapping, merge_span_masking
24
+ from fourm.utils.generation import cosine_schedule, linear_schedule, onex_temp_schedule, linear_temp_schedule, continue_schedule
25
+ from tqdm import tqdm
26
+ import copy
27
+
28
+
29
+
30
+ def empty_img_modality(mod_dict, key):
31
+ # Input mask
32
+ mod_dict[key]['input_mask'][:] = True
33
+
34
+ # Target Mask
35
+ mod_dict[key]['target_mask'][:] = False
36
+
37
+ return mod_dict
38
+
39
+ def empty_seq_modality(mod_dict, key, s1_id=5):
40
+ # To create an empty sequence, we assume an input budget of a single token, with the rest assigned to targets
41
+
42
+ # Input tensor
43
+ # Input is [S_1], target is [S_1] ...... [S_2]
44
+ # (so [S_1] [S_1] ..... [S_2] when combined)
45
+ mod_dict[key]['tensor'][:] = 0
46
+ mod_dict[key]['tensor'][:,[0,1]] = s1_id # s1_id is id of the first sentinel token ([S_1])
47
+ mod_dict[key]['tensor'][:,-1] = s1_id + 1
48
+
49
+ # Input mask
50
+ # Set first token to input (i.e. 0), rest to target (i.e. 1)
51
+ mod_dict[key]['input_mask'][:] = True
52
+ mod_dict[key]['input_mask'][:,0] = False
53
+
54
+ # Target Mask
55
+ mod_dict[key]['target_mask'] = ~mod_dict[key]['input_mask']
56
+
57
+ # Decoder attn mask
58
+ # WARNING: Not needed / used in GenerationSampler, where causal mask is enforced
59
+ # First token is input, not part of target
60
+ mod_dict[key]['decoder_attention_mask'][:] = 1
61
+ mod_dict[key]['decoder_attention_mask'][:, 0] = 0
62
+
63
+ return mod_dict
64
+
65
+ def empty_seq_emb_modality(mod_dict, key):
66
+ # Tensor
67
+ mod_dict[key]['tensor'] = torch.zeros_like(mod_dict[key]['tensor'])
68
+
69
+ # Input mask
70
+ mod_dict[key]['input_mask'] = torch.ones_like(mod_dict[key]['input_mask'])
71
+ # It is crucial to specify the input mask as such, CFG won't work otherwise!
72
+ mod_dict[key]['input_mask'][:, 0] = False
73
+
74
+ # Target Mask
75
+ mod_dict[key]['target_mask'] = torch.ones_like(mod_dict[key]['target_mask'])
76
+
77
+ # Decoder attn mask
78
+ mod_dict[key]['decoder_attention_mask'][:] = False
79
+
80
+ return mod_dict
81
+
82
+
83
+ def init_empty_target_modality(mod_dict, modality_info, domain, batch_size, num_tokens, device):
84
+ """
85
+ Initializes an empty target modality entry for a given domain.
86
+ Used to set up the target modality dictionaries for generation.
87
+ """
88
+ if modality_info[domain]['type'] == 'img':
89
+ # Initialize mod dict
90
+ mod_dict[domain] = {
91
+ 'tensor': torch.zeros((batch_size, num_tokens), dtype=torch.int64, device=device),
92
+ 'input_mask': torch.ones((batch_size, num_tokens), dtype=torch.bool, device=device),
93
+ 'target_mask': torch.zeros((batch_size, num_tokens), dtype=torch.bool, device=device),
94
+ }
95
+ # Set it to the correct values
96
+ mod_dict = empty_img_modality(mod_dict, domain)
97
+
98
+ elif modality_info[domain]['type'] in ['seq', 'seq_token', 'seq_emb']:
99
+ # Initialize mod dict
100
+ num_tokens = max(num_tokens, 2)
101
+ mod_dict[domain] = {
102
+ 'tensor': torch.zeros((batch_size, num_tokens), dtype=torch.int32, device=device),
103
+ 'input_mask': torch.ones((batch_size, num_tokens), dtype=torch.bool, device=device),
104
+ 'target_mask': torch.zeros((batch_size, num_tokens), dtype=torch.bool, device=device),
105
+ 'decoder_attention_mask': torch.zeros((batch_size, num_tokens), dtype=torch.bool, device=device),
106
+ }
107
+ # Set it to the correct values
108
+ if modality_info[domain]['type'] in ['seq', 'seq_token']:
109
+ mod_dict = empty_seq_modality(mod_dict, domain)
110
+ elif modality_info[domain]['type'] == 'seq_emb':
111
+ mod_dict = empty_seq_emb_modality(mod_dict, domain)
112
+ else:
113
+ raise ValueError()
114
+
115
+ return mod_dict
116
+
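A small sketch of preparing an empty image target with the helper above. The domain key and token count are hypothetical, and MODALITY_INFO (imported as in fm_vit.py) is assumed to mark that key as type 'img'; with the mask convention used here (False = active), the result exposes no input tokens and marks every position for prediction.

from fourm.data.modality_info import MODALITY_INFO

mod_dict = init_empty_target_modality(
    {}, modality_info=MODALITY_INFO, domain='tok_rgb@224',  # hypothetical domain key
    batch_size=1, num_tokens=196, device='cpu')
# mod_dict['tok_rgb@224']['input_mask'] is all True, 'target_mask' is all False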
117
+ def init_full_input_modality(mod_dict, modality_info, domain, device, eos_id=3):
118
+ if domain.startswith('rgb'):
119
+ batch_size, _, H, W = mod_dict[domain]['tensor'].shape
120
+ patch_size = modality_info[domain]['patch_size']
121
+ num_tokens = (H // patch_size) * (W // patch_size)
122
+ shape = (batch_size, num_tokens)
123
+ else:
124
+ shape = mod_dict[domain]['tensor'].shape
125
+ if 'input_mask' not in mod_dict[domain]:
126
+ mod_dict[domain]['input_mask'] = torch.zeros(shape, dtype=torch.bool, device=device)
127
+ if 'target_mask' not in mod_dict[domain]:
128
+ mod_dict[domain]['target_mask'] = torch.ones(shape, dtype=torch.bool, device=device)
129
+ if 'decoder_attention_mask' not in mod_dict[domain]:
130
+ mod_dict[domain]['decoder_attention_mask'] = torch.zeros(shape, dtype=torch.bool, device=device)
131
+
132
+ if modality_info[domain]['type'] == 'img':
133
+ mod_dict[domain]['input_mask'][:] = False
134
+ mod_dict[domain]['target_mask'][:] = True
135
+
136
+ elif modality_info[domain]['type'] in ['seq', 'seq_token']:
137
+ if eos_id in mod_dict[domain]['tensor']:
138
+ eos_idx = torch.where(mod_dict[domain]['tensor'] == eos_id)[1][0].item()
139
+ else:
140
+ mod_dict[domain]['tensor'][:,0] = eos_id
141
+ eos_idx = 0
142
+ mod_dict[domain]['input_mask'][:,:eos_idx+1] = False
143
+ mod_dict[domain]['input_mask'][:,eos_idx+1:] = True
144
+ mod_dict[domain]['target_mask'][:] = True
145
+
146
+ elif modality_info[domain]['type'] in ['seq_emb']:
147
+ # T5 caption has the valid mask saved alongside the embeddings
148
+ mod_dict[domain]['input_mask'] = ~mod_dict[domain]['mask_valid']
149
+ mod_dict[domain]['target_mask'] = torch.ones_like(mod_dict[domain]['mask_valid'])
150
+ mod_dict[domain]['decoder_attention_mask'] = torch.zeros_like(mod_dict[domain]['mask_valid'])
151
+
152
+ return mod_dict
153
+
154
+ def custom_text(sample, input_text, eos_token, key, device, text_tokenizer, target_max_len=50, start_token="[S_1]"):
155
+ input_ids = text_tokenizer.encode(input_text).ids
156
+ input_ids = torch.tensor(input_ids).unsqueeze(0)
157
+
158
+ target_text = [start_token]
159
+ target_text.extend(["[PAD]"] * (target_max_len - 2))
160
+ target_text.append(eos_token)
161
+ target_text = " ".join(target_text)
162
+ target_ids = text_tokenizer.encode(target_text).ids
163
+ target_ids = torch.tensor(target_ids).unsqueeze(0)
164
+
165
+ all_ids = torch.cat([input_ids, target_ids], dim=1)
166
+
167
+ input_mask = torch.cat([
168
+ torch.zeros_like(input_ids, dtype=torch.bool),
169
+ torch.ones_like(target_ids, dtype=torch.bool),
170
+ ], dim=1)
171
+
172
+ target_mask = torch.cat([
173
+ torch.ones_like(input_ids, dtype=torch.bool),
174
+ torch.zeros_like(target_ids, dtype=torch.bool),
175
+ ], dim=1)
176
+
177
+ sample[key] = {}
178
+ sample[key]['tensor'] = all_ids.to(device)
179
+ sample[key]['input_mask'] = input_mask.to(device)
180
+ sample[key]['target_mask'] = target_mask.to(device)
181
+ sample[key]['decoder_attention_mask'] = torch.zeros(all_ids.shape, dtype=torch.bool, device=device)
182
+
183
+ return sample
184
+
185
+ def expand_to_batch(mod_dict, batch_size):
186
+ for mod, d in mod_dict.items():
187
+ for k, v in d.items():
188
+ if k in ['tensor', 'input_mask', 'target_mask', 'decoder_attention_mask', 'mask_valid']:
189
+ B = v.shape[0]
190
+ if B == 1:
191
+ mod_dict[mod][k] = repeat(v, "1 ... -> b ...", b=batch_size)
192
+ elif B != batch_size:
193
+ raise ValueError(f"Invalid batch size: {B} instead of {batch_size}")
194
+
195
+ return mod_dict
196
+
197
+ def build_chained_generation_schedules(
198
+ cond_domains: List[str],
199
+ target_domains: List[str],
200
+ tokens_per_target: List[int],
201
+ autoregression_schemes: List[str],
202
+ decoding_steps: List[int],
203
+ token_decoding_schedules: List[str],
204
+ temps: List[float],
205
+ temp_schedules: List[str],
206
+ cfg_scales: List[float],
207
+ cfg_schedules: List[str],
208
+ cfg_grow_conditioning: bool = False,
209
+ modality_info: Optional[dict] = None,
210
+ ):
211
+ """
212
+ Builds a list of chained generation schedules, where each schedule step is a dict of the form:
214
+ {target_domain, scheme, num_tokens (number of decoded tokens), temperature, cfg_scale, cfg_cond_domains}
214
+
215
+ Args:
216
+ cond_domains: List of conditioning domains
217
+ target_domains: List of target domains
218
+ tokens_per_target: List of number of tokens to decode for each target domain
219
+ autoregression_schemes: List of autoregression schemes for each target domain. maskgit, roar, or autoregressive
220
+ decoding_steps: List of number of maskgit steps for each target domain (if applicable)
221
+ token_decoding_schedules: List of maskgit token schedules for each target domain (if applicable). cosine or linear
222
+ temps: List of starting temperatures for each target domain
223
+ temp_schedules: List of temperature schedules for each target domain. linear, constant, or onex:{min_t}:{power}
224
+ cfg_scales: List of classifier-free guidance scales for each target domain
225
+ cfg_schedules: List of classifier-free guidance schedules for each target domain. constant or cosine
226
+ cfg_grow_conditioning: After every completed modality, add them to classifier-free guidance conditioning
227
+ modality_info: Dictionary with metadata for each modality, optionally used to verify that the schedule is compatible with the modality
228
+ """
229
+
230
+ # List of {target_modality, schema, number of decoded tokens, temperature, guidance_scale, cfg_cond_domains} dicts
231
+ chained_schedules = []
232
+
233
+ cond_domains = cond_domains.copy()
234
+
235
+ for target_idx in range(len(target_domains)):
236
+
237
+ scheme = autoregression_schemes[target_idx]
238
+ target_domain = target_domains[target_idx]
239
+ ntoks = tokens_per_target[target_idx]
240
+ maskgit_token_schedule_name = token_decoding_schedules[target_idx]
241
+ temp = temps[target_idx]
242
+ temp_schedule_name = temp_schedules[target_idx]
243
+ cfg_scale = cfg_scales[target_idx]
244
+ cfg_schedule_name = cfg_schedules[target_idx]
245
+
246
+ # Auto-regressive (caption, detection, ...)
247
+ if scheme == 'autoregressive':
248
+ chained_schedules.append({
249
+ 'target_domain': target_domain,
250
+ 'scheme': scheme,
251
+ 'num_tokens': None,
252
+ 'temperature': temp,
253
+ 'cfg_scale': cfg_scale,
254
+ 'cfg_cond_domains': cond_domains.copy()
255
+ })
256
+ continue
257
+
258
+ # If modality_info is provided, use it to assert that this scheme is valid for the target modality
259
+ if modality_info is not None:
260
+ assert modality_info[target_domain]['type'] not in ['seq', 'seq_token'], f'Illegal autoregressive scheme {scheme} for target domain {target_domain}'
261
+
262
+ # Token schedule
263
+ if scheme == 'maskgit':
264
+ # MaskGIT token schedule setup
265
+ num_steps = decoding_steps[target_idx]
266
+ if maskgit_token_schedule_name == 'cosine':
267
+ token_schedule = cosine_schedule(num_steps, (ntoks))
268
+ elif maskgit_token_schedule_name == 'linear':
269
+ token_schedule = linear_schedule(num_steps, (ntoks))
270
+ else:
271
+ raise ValueError(f'Illegal MaskGIT token schedule {maskgit_token_schedule_name}')
272
+ elif scheme == 'roar':
273
+ # ROAR token schedule setup (one-by-one, but random order)
274
+ num_steps = decoding_steps[target_idx]
275
+ token_schedule = linear_schedule(num_steps, ntoks)
276
+ else:
277
+ raise ValueError(f'Illegal decoding scheme {scheme}')
278
+
279
+ # Temperature schedule
280
+ if temp_schedule_name == 'linear':
281
+ temp_schedule = linear_temp_schedule(temp, token_schedule)
282
+ elif temp_schedule_name == 'constant':
283
+ temp_schedule = temp * np.ones(num_steps)
284
+ elif 'onex' in temp_schedule_name:
285
+ # onex temperature schedule has to be formatted like onex:{min_t}:{power}
286
+ min_t, power = [float(f) for f in temp_schedule_name.split(':')[1:]]
287
+ temp_schedule = onex_temp_schedule(max_t=temp, min_t=min_t, token_schedule=token_schedule, power=power)
288
+ else:
289
+ raise ValueError(f'Illegal temperature schedule {temp_schedule_name}')
290
+
291
+ # Classifier-free guidance scale schedule
292
+ if cfg_schedule_name == 'constant':
293
+ if isinstance(cfg_scale, float):
294
+ cfg_schedule = cfg_scale * np.ones(num_steps)
295
+ elif isinstance(cfg_scale, list):
296
+ cfg_schedule = np.array(cfg_scale) * np.ones(num_steps).reshape(-1, 1)
297
+ elif cfg_schedule_name == 'cosine':
298
+ raise NotImplementedError()
299
+ else:
300
+ raise ValueError(f'Illegal guidance schedule {cfg_schedule_name}')
301
+
302
+ # Concatenate schedule for this modality with previous ones
303
+ schedule = [
304
+ {
305
+ 'target_domain': target_domain,
306
+ 'scheme': scheme,
307
+ 'num_tokens': tok,
308
+ 'temperature': temp,
309
+ 'cfg_scale': cfg,
310
+ 'cfg_cond_domains': cond_domains.copy()
311
+ }
312
+ for tok, temp, cfg in zip(token_schedule, temp_schedule, cfg_schedule)
313
+ ]
314
+ chained_schedules.extend(schedule)
315
+
316
+ # Optionally add this new modality to the ones affected by classifier-free guidance
317
+ if cfg_grow_conditioning:
318
+ cond_domains.append(target_domain)
319
+
320
+ return chained_schedules
321
+
322
+
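A hedged example of building a single-target schedule with the function above; the domain names are hypothetical and only the list lengths matter here. With 8 MaskGIT steps over 196 tokens it yields 8 step dicts whose num_tokens follow the cosine schedule.

schedule = build_chained_generation_schedules(
    cond_domains=['caption'],        # hypothetical conditioning domain
    target_domains=['tok_rgb@224'],  # hypothetical target domain
    tokens_per_target=[196],
    autoregression_schemes=['maskgit'],
    decoding_steps=[8],
    token_decoding_schedules=['cosine'],
    temps=[1.0],
    temp_schedules=['linear'],
    cfg_scales=[3.0],
    cfg_schedules=['constant'],
)
# schedule[0] -> {'target_domain': 'tok_rgb@224', 'scheme': 'maskgit', 'num_tokens': ...,
#                 'temperature': ..., 'cfg_scale': 3.0, 'cfg_cond_domains': ['caption']}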
323
+ class GenerationSampler(nn.Module):
324
+ """Sampler that wraps a trained 4M model for generation use cases.
325
+ Implements standard autoregressive, MaskGIT, and ROAR generation schemes with chaining and weighted guidance."""
326
+
327
+ def __init__(self, model):
328
+ super().__init__()
329
+ self.model = model
330
+
331
+
332
+ def top_k_top_p_filtering(self, logits, top_k=0.0, top_p=0.0):
333
+ # Compatible with batching
334
+ # From https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
335
+ if top_k > 0.0:
336
+ if isinstance(top_k, int):
337
+ k = min(top_k, logits.shape[-1])
338
+ elif isinstance(top_k, float):
339
+ k = min(int(top_k * logits.shape[-1]), logits.shape[-1])
340
+ else:
341
+ raise ValueError(f"Invalid value for top_k: {top_k}")
342
+
343
+ # Remove all tokens with a probability less than the last token of the top-k
344
+ indices_to_remove = logits < torch.topk(logits, k)[0][..., -1, None]
345
+ logits[indices_to_remove] = float("-inf")
346
+
347
+ if top_p > 0.0:
348
+ sorted_logits, sorted_indices = torch.sort(logits, dim=1, descending=True)
349
+ cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
350
+ sorted_indices_to_remove = cum_probs > top_p
351
+ # Shift the indices to the right to keep also the first token above the threshold
352
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
353
+ sorted_indices_to_remove[..., 0] = 0
354
+
355
+ restore_indices = torch.argsort(sorted_indices, dim=-1)
356
+ indices_to_remove = torch.gather(sorted_indices_to_remove, dim=-1, index=restore_indices)
357
+ logits[indices_to_remove] = float("-inf")
358
+
359
+ return logits
360
+
361
+ def sample_tokens(self, logits, temperature=1.0, top_k=0.0, top_p=0.0):
362
+ if np.isclose(temperature, 0, atol=1e-10):
363
+ samples = torch.argmax(logits, dim=-1)
364
+ # Since argmax is used, all sampled_probs will be 1 as we're selecting the max probability
365
+ sampled_probs = torch.ones_like(samples, dtype=torch.float32)
366
+ else:
367
+ filtered_logits = self.top_k_top_p_filtering(logits, top_k, top_p)
368
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
369
+ samples = torch.multinomial(probs, 1)[:, 0]
370
+ sampled_probs = probs[torch.arange(len(samples)), samples]
371
+ return samples, sampled_probs
372
+
373
+ def sample_tokens_batched(self, logits, temperature=1.0, top_k=0.0, top_p=0.0):
374
+ if logits.ndim > 2:
375
+ B, N = logits.shape[0], logits.shape[1]
376
+ logits = rearrange(logits, 'b n v -> (b n) v')
377
+ samples, sampled_probs = self.sample_tokens(logits, temperature, top_k, top_p)
378
+ samples = rearrange(samples, '(b n) -> b n', b=B, n=N)
379
+ sampled_probs = rearrange(sampled_probs, '(b n) -> b n', b=B, n=N)
380
+ return samples, sampled_probs
381
+ else:
382
+ return self.sample_tokens(logits, temperature, top_k, top_p)
383
+
384
+ def select_tokens(self, logits, num_select, temperature=1.0, top_k=0.0, top_p=0.0, return_all_samples=False):
385
+ samples, sampled_probs = self.sample_tokens(logits, temperature, top_k, top_p)
386
+ top_indices = torch.topk(sampled_probs, num_select)[1]
387
+ top_samples = samples[top_indices]
388
+ if return_all_samples:
389
+ return top_samples, top_indices, samples
390
+ else:
391
+ return top_samples, top_indices
392
+
393
+ def select_tokens_batched(self, logits, num_select, temperature=1.0, top_k=0.0, top_p=0.0, return_all_samples=False):
394
+ if logits.ndim > 2:
395
+ samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k, top_p) # both of shape (B, N)
396
+ top_indices = torch.topk(sampled_probs, num_select, dim=-1)[1]
397
+ # Need to switch to gather instead of indexing here
398
+ top_samples = torch.gather(samples, dim=-1, index=top_indices)
399
+ if return_all_samples:
400
+ return top_samples, top_indices, samples
401
+ else:
402
+ return top_samples, top_indices
403
+ else:
404
+ return self.select_tokens(logits, num_select, temperature, top_k, top_p, return_all_samples)
405
+
406
+
407
+ def forward_mask_encoder_generation(self, encoder_mod_dict):
408
+ """Modification of forward_mask_encoder adapted for generation, with support for batching
409
+ """
410
+ # Form input
411
+ B = list(encoder_mod_dict.values())[0]['tensor'].shape[0]
412
+
413
+ encoder_tokens_all, emb_all, encoder_mask_all, mod_mask_all = self.model.cat_encoder_tensors(encoder_mod_dict)
414
+ # Take the max number of encoder tokens (although assuming it's the same everywhere would be better)
415
+ num_encoder_tokens = (~encoder_mask_all.reshape(B, -1)).sum(dim=1).max()
416
+
417
+ # Add arange multiplied by small constant to mask so they get sorted in a deterministic way
418
+ mask_arange = torch.arange(encoder_mask_all.shape[1], device=encoder_mask_all.device).unsqueeze(0) * 1e-6
419
+ ids_shuffle = torch.argsort(encoder_mask_all + mask_arange, dim=1)
420
+ # ids_restore = torch.argsort(ids_shuffle, dim=1)
421
+ ids_keep = ids_shuffle[:, :num_encoder_tokens]
422
+
423
+ encoder_tokens = torch.gather(encoder_tokens_all, dim=1,
424
+ index=repeat(ids_keep, "b n -> b n d", d=encoder_tokens_all.shape[2]))
425
+ encoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
426
+ encoder_mask = torch.gather(encoder_mask_all, dim=1, index=ids_keep)
427
+ mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)
428
+
429
+ if self.model.num_register_tokens > 0:
430
+ prompt_tokens = repeat(self.prompt_tokens, '() n d -> b n d', b=B)
431
+ # We add prompt tokens at the beginning of the sequence
432
+ encoder_tokens = torch.cat([prompt_tokens, encoder_tokens], dim=1)
433
+ encoder_emb = torch.cat([torch.zeros_like(prompt_tokens), encoder_emb], dim=1)
434
+ encoder_mask = torch.cat([torch.zeros((B, prompt_tokens.shape[1]), dtype=torch.bool, device=encoder_mask.device), encoder_mask], dim=1)
435
+ mod_mask = torch.cat([torch.full((B, prompt_tokens.shape[1]), -1, dtype=torch.int16, device=mod_mask.device), mod_mask], dim=1)
436
+
437
+ encoder_tokens[encoder_mask] = 0.
438
+ encoder_emb[encoder_mask] = 0.
439
+ mod_mask[encoder_mask] = -1
440
+ # Mask could be of shape 'b n1 n2' but not needed for masked_fill
441
+ # This means this mask can then be re-used for decoder cross-attention
442
+ encoder_mask = rearrange(encoder_mask, 'b n2 -> b 1 n2')
443
+
444
+ return encoder_tokens, encoder_emb, encoder_mask, mod_mask
445
+
446
+
447
+ def forward_mask_decoder_maskgit(self, mod_dict, target_mod, seed=None):
448
+ """Modification of forward_mask_decoder for MaskGIT generation, with support for batching
449
+ """
450
+ if seed is not None:
451
+ torch.manual_seed(seed)
452
+ d = mod_dict[target_mod]
453
+ decoder_tokens_all = torch.zeros_like(d['x']) + self.model.mask_token
454
+ emb_all = d['emb']
455
+ decoder_mask_all = d['target_mask']
456
+ B = decoder_tokens_all.shape[0] # Get batch size
457
+ mod_mask_all = torch.full_like(d['ids'], self.model.modality_info[target_mod]['id'], dtype=torch.int16)
458
+ mod_pos_all = torch.arange(d['x'].shape[1], device=d['x'].device).unsqueeze(0)
459
+ mod_pos_all = repeat(mod_pos_all, '1 n -> b n', b=B) # Added: Expansion for batching
460
+ num_decoder_tokens = (~decoder_mask_all[0]).sum() # Adapted for batching / Assumes num_decoder_tokens is the same across the batch
461
+
462
+
463
+ # Add arange multiplied by small constant to mask so they get sorted in a deterministic way
464
+ mask_arange = torch.arange(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6
465
+ ids_shuffle = torch.argsort(decoder_mask_all + mask_arange, dim=1)
466
+ # ids_restore = torch.argsort(ids_shuffle, dim=1)
467
+ ids_keep = ids_shuffle[:, :num_decoder_tokens]
468
+
469
+ decoder_tokens = torch.gather(decoder_tokens_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=decoder_tokens_all.shape[2]))
470
+ decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
471
+ decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep)
472
+ mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)
473
+ mod_pos = torch.gather(mod_pos_all, dim=1, index=ids_keep)
474
+
475
+ decoder_tokens[decoder_mask] = 0.
476
+ decoder_emb[decoder_mask] = 0.
477
+ mod_mask[decoder_mask] = -1
478
+
479
+ return decoder_tokens, decoder_emb, decoder_mask, mod_mask, mod_pos
480
+
481
+ def forward_mask_decoder_roar(self, mod_dict, target_mod, num_select, seed=None):
482
+ """Modification of forward_mask_decoder for ROAR generation, with support for batching
483
+ """
484
+ if seed is not None:
485
+ torch.manual_seed(seed)
486
+ d = mod_dict[target_mod]
487
+ decoder_tokens_all = torch.zeros_like(d['x']) + self.model.mask_token
488
+ emb_all = d['emb']
489
+ decoder_mask_all = d['target_mask']
490
+ B = decoder_tokens_all.shape[0] # Get batch size
491
+ mod_mask_all = torch.full_like(d['ids'], self.model.modality_info[target_mod]['id'], dtype=torch.int16)
492
+ mod_pos_all = torch.arange(d['x'].shape[1], device=d['x'].device).unsqueeze(0)
493
+ mod_pos_all = repeat(mod_pos_all, '1 n -> b n', b=B) # Added: Expansion for batching
494
+ # Only keep the first num_select tokens
495
+ num_decoder_tokens = min(num_select, (~decoder_mask_all[0]).sum()) # Adapted for batching / Assumes num_decoder_tokens is the same across the batch
496
+
497
+ # Add a small random number to the mask so they get sorted in a random way, but keeping the masked tokens first
498
+ mask_rand = torch.rand(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6
499
+ ids_shuffle = torch.argsort(decoder_mask_all + mask_rand, dim=1)
500
+ # ids_restore = torch.argsort(ids_shuffle, dim=1)
501
+ # Only keep the first num_select_tokens
502
+ ids_keep = ids_shuffle[:, :num_decoder_tokens]
503
+
504
+ decoder_tokens = torch.gather(decoder_tokens_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=decoder_tokens_all.shape[2]))
505
+ decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
506
+ decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep)
507
+ mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)
508
+ mod_pos = torch.gather(mod_pos_all, dim=1, index=ids_keep)
509
+
510
+ decoder_tokens[decoder_mask] = 0.
511
+ decoder_emb[decoder_mask] = 0.
512
+ mod_mask[decoder_mask] = -1
513
+
514
+ return decoder_tokens, decoder_emb, decoder_mask, mod_mask, mod_pos
515
+
516
+ def forward_mask_decoder_autoregressive(self, mod_dict, target_mod, seed=None):
517
+ # Adapted for batching
518
+ if seed is not None:
519
+ torch.manual_seed(seed)
520
+ # This is the concatenation part
521
+ d = mod_dict[target_mod]
522
+ decoder_ids_all = d['ids']
523
+ emb_all = d['emb']
524
+ decoder_mask_all = d['target_mask']
525
+ B = decoder_ids_all.shape[0] # Get batch size
526
+ mod_mask_all = torch.full_like(d['ids'], self.model.modality_info[target_mod]['id'], dtype=torch.int16)
527
+ mod_pos_all = torch.arange(d['x'].shape[1], device=d['x'].device).unsqueeze(0)
528
+ mod_pos_all = repeat(mod_pos_all, '1 n -> b n', b=B)
529
+ num_decoder_tokens = (~decoder_mask_all[0]).sum() # Adapted for batching, but assumes num_decoder_tokens is the same across the batch
530
+
531
+ # Add arange multiplied by small constant to mask so they get sorted in a deterministic way
532
+ mask_arange = torch.arange(decoder_mask_all.shape[1], device=decoder_mask_all.device).unsqueeze(0) * 1e-6
533
+ ids_shuffle = torch.argsort(decoder_mask_all + mask_arange, dim=1)
534
+ # ids_restore = torch.argsort(ids_shuffle, dim=1)
535
+ ids_keep = ids_shuffle[:, :num_decoder_tokens]
536
+
537
+ # Same as in forward_mask_decoder
538
+ decoder_ids = torch.gather(decoder_ids_all, dim=1, index=ids_keep)
539
+ decoder_emb = torch.gather(emb_all, dim=1, index=repeat(ids_keep, "b n -> b n d", d=emb_all.shape[2]))
540
+ decoder_mask = torch.gather(decoder_mask_all, dim=1, index=ids_keep)
541
+ mod_mask = torch.gather(mod_mask_all, dim=1, index=ids_keep)
542
+ mod_pos = torch.gather(mod_pos_all, dim=1, index=ids_keep)
543
+
544
+ decoder_ids[decoder_mask] = 0
545
+ decoder_emb[decoder_mask] = 0.
546
+ mod_mask[decoder_mask] = -1
547
+
548
+ return decoder_ids, decoder_emb, decoder_mask, mod_mask, mod_pos
549
+
550
+ def merge_sequences(self, mod_dict, pred_ids, target_mod, text_tokenizer, default_sentinel="[S_1]"):
551
+ device = mod_dict[target_mod]['tensor'].device
552
+ # Get input ids
553
+ input_ids = mod_dict[target_mod]['tensor'].squeeze().detach().cpu()
554
+ input_ids = input_ids[mod_dict[target_mod]['input_mask'].squeeze().detach().cpu() == 0]
555
+ input_ids = input_ids.tolist()
556
+
557
+ if len(input_ids) == 0:
558
+ input_ids = [text_tokenizer.get_vocab()[default_sentinel]]
559
+
560
+ # Get predicted ids
561
+ pred_ids = pred_ids.squeeze().detach().cpu().tolist()
562
+ if isinstance(pred_ids, int):
563
+ pred_ids = [pred_ids]
564
+
565
+ # Get sentinel ids using the tokenizer
566
+ sentinel_ids = set(get_sentinel_to_id_mapping(text_tokenizer).values())
567
+ # Perform merging
568
+ merged_ids = merge_span_masking(input_ids, pred_ids, sentinel_ids)
569
+ merged_ids = torch.tensor(merged_ids).unsqueeze(0)
570
+ # Create new dict
571
+ new_input_mask = torch.zeros_like(merged_ids, dtype=torch.bool)
572
+ new_target_mask = torch.ones_like(merged_ids, dtype=torch.bool)
573
+ new_dict = {'tensor': merged_ids.to(device),
574
+ 'input_mask': new_input_mask.to(device),
575
+ 'target_mask': new_target_mask.to(device)}
576
+ new_dict['decoder_attention_mask'] = torch.zeros_like(new_target_mask, dtype=torch.bool)
577
+
578
+ mod_dict[target_mod] = new_dict
579
+ return mod_dict
580
+
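merge_span_masking, used above and implemented elsewhere in the repo, stitches the predicted span contents back into the sentinel-masked input sequence, T5-style. The toy function below is only a simplified sketch of that data flow under the assumption of T5-style span corruption; it is not the upstream implementation:

def toy_merge_spans(input_ids, pred_ids, sentinel_ids):
    # Collect, for each sentinel appearing in the predictions, the tokens that follow it.
    spans, current = {}, None
    for tok in pred_ids:
        if tok in sentinel_ids:
            current = tok
            spans[current] = []
        elif current is not None:
            spans[current].append(tok)
    # Replace every sentinel in the input by its predicted span.
    merged = []
    for tok in input_ids:
        merged.extend(spans.get(tok, [tok]) if tok in sentinel_ids else [tok])
    return merged

# With sentinel ids {100, 101}:
# toy_merge_spans([5, 100, 6, 101], [100, 7, 8, 101, 9], {100, 101}) -> [5, 7, 8, 6, 9]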
581
+ def merge_sequences_batched(self, mod_dict, pred_ids, target_mod, text_tokenizer, default_sentinel="[S_1]"):
582
+ # Unbatches and calls merge_sequences per sample, then regroups the results into a batch
583
+
584
+ pad_id = text_tokenizer.token_to_id("[PAD]")
585
+
586
+ B = mod_dict[target_mod]['tensor'].shape[0]
587
+ device = mod_dict[target_mod]['tensor'].device
588
+
589
+ tensors = torch.split(mod_dict[target_mod]['tensor'], 1)
590
+ input_masks = torch.split(mod_dict[target_mod]['input_mask'], 1)
591
+ pred_ids = torch.split(pred_ids, 1)
592
+
593
+ input_dicts = []
594
+ for t, im in zip(tensors, input_masks):
595
+ d = {target_mod: {'tensor': t, 'input_mask': im}}
596
+ input_dicts.append(d)
597
+
598
+ merged_tensors = []
599
+ merged_input_masks = []
600
+ merged_target_masks = []
601
+ merged_seq_lens = []
602
+ for input_d, pi in zip(input_dicts, pred_ids):
603
+ # Output of merge_sequences is mod_dict with modified target mod
604
+ merged_d = self.merge_sequences(input_d, pi, target_mod, text_tokenizer, default_sentinel)[target_mod]
605
+ merged_tensors.append(merged_d['tensor'])
606
+ merged_input_masks.append(merged_d['input_mask'])
607
+ merged_target_masks.append(merged_d['target_mask'])
608
+ merged_seq_lens.append(merged_d['tensor'].shape[1])
609
+
610
+
611
+ max_seq_len = max(merged_seq_lens)
612
+
613
+ for i in range(len(merged_tensors)):
614
+ # Right pad all tensors
615
+ p1d = (0, max_seq_len - merged_seq_lens[i])
616
+ merged_tensors[i] = F.pad(merged_tensors[i], p1d, "constant", pad_id)
617
+ merged_input_masks[i] = F.pad(merged_input_masks[i], p1d, "constant", True)
618
+ merged_target_masks[i] = F.pad(merged_target_masks[i], p1d, "constant", True)
619
+
620
+ new_dict = {'tensor': torch.cat(merged_tensors, dim=0).to(device),
621
+ 'input_mask': torch.cat(merged_input_masks, dim=0).to(device),
622
+ 'target_mask': torch.cat(merged_target_masks, dim=0).to(device)}
623
+ new_dict['decoder_attention_mask'] = torch.zeros_like(new_dict['target_mask'], dtype=torch.bool)
624
+
625
+ mod_dict[target_mod] = new_dict
626
+ return mod_dict
627
+
628
+ def forward_enc_dec_maskgit_batched(self, mod_dict, target_mod, seed=None):
629
+ # Encoder
630
+ encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d)
631
+ for mod, d in mod_dict.items()
632
+ if mod in self.model.encoder_embeddings}
633
+ encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict)
634
+ x = encoder_tokens + encoder_emb
635
+ x = self.model.forward_encoder(x, encoder_mask)
636
+
637
+ # Decoder
638
+ context = self.model.decoder_proj_context(x) + encoder_emb
639
+ decoder_mod_dict = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])}
640
+ decoder_tokens, decoder_emb, decoder_mask, decoder_mod_mask, mod_pos = self.forward_mask_decoder_maskgit(decoder_mod_dict, target_mod, seed=seed)
641
+ y = decoder_tokens + decoder_emb
642
+ y = self.model.forward_decoder(y, context, encoder_mask, None)
643
+ B, N, D = y.shape
644
+ logits = self.model.forward_logits(y, decoder_mod_dict, decoder_mod_mask)[target_mod]
645
+ logits = logits.reshape(B, N, -1)
646
+
647
+
648
+ return logits, mod_pos
649
+
650
+ def maskgit_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p, seed=None):
651
+ logits, mod_pos = self.forward_enc_dec_maskgit_batched(mod_dict, target_mod, seed=seed)
652
+
653
+ # MaskGIT sampling
654
+ top_samples, top_indices = self.select_tokens_batched(logits, num_select,
655
+ temperature=temperature, top_k=top_k, top_p=top_p)
656
+ # Update mod dict
657
+ # We rely on gather / scatter for batched operations
658
+ top_pos = torch.gather(mod_pos, -1, top_indices) # (B, num_select)
659
+ mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, top_pos, top_samples)
660
+ mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool))
661
+ mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool))
662
+
663
+ return mod_dict
664
+
665
+ def guided_maskgit_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p,
666
+ conditioning=[], guidance_scale=1.0, seed=None, write_all_predictions=False):
667
+
668
+ ### 1 - First pass, with conditioning
669
+ logits_cond, _ = self.forward_enc_dec_maskgit_batched(mod_dict, target_mod, seed=seed)
670
+
671
+ ### 2 - Second pass, without conditioning
672
+ mod_dict_uncond = copy.deepcopy(mod_dict)
673
+ for mod in conditioning:
674
+ if self.model.modality_info[mod]['type'] in ['seq', 'seq_token']:
675
+ mod_dict_uncond = empty_seq_modality(mod_dict_uncond, mod)
676
+ elif self.model.modality_info[mod]['type'] in ['seq_emb']:
677
+ mod_dict_uncond = empty_seq_emb_modality(mod_dict_uncond, mod)
678
+ else:
679
+ mod_dict_uncond = empty_img_modality(mod_dict_uncond, mod)
680
+
681
+ logits_uncond, mod_pos = self.forward_enc_dec_maskgit_batched(mod_dict_uncond, target_mod, seed=seed)
682
+
683
+ ### 3 - Classifier-free guidance
684
+ logits = logits_uncond + (logits_cond - logits_uncond) * guidance_scale
685
+
686
+ ### 4 - MaskGIT sampling
687
+ top_samples, top_indices, all_samples = self.select_tokens_batched(
688
+ logits, num_select,
689
+ temperature=temperature, top_k=top_k, top_p=top_p,
690
+ return_all_samples=True
691
+ )
692
+
693
+ ### 5 - Update mod dict
694
+ # We rely on gather / scatter for batched operations
695
+ top_pos = torch.gather(mod_pos, -1, top_indices) # (B, num_select)
696
+ if write_all_predictions:
697
+ mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, mod_pos, all_samples)
698
+ else:
699
+ mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, top_pos, top_samples)
700
+ mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool))
701
+ mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool))
702
+
703
+ return mod_dict
704
+
705
+ def multi_guided_maskgit_step_batched(self, uncond_dict, cond_dicts, cond_weights, target_mod, num_select,
706
+ temperature, top_k, top_p, seed=None, write_all_predictions=False):
707
+
708
+ ### 1 - Conditional forward passes (one for each guided condition)
709
+ logits_cond_all = []
710
+ for cond_dict in cond_dicts:
711
+ logits_cond_i, _ = self.forward_enc_dec_maskgit_batched(cond_dict, target_mod, seed=seed)
712
+ logits_cond_all.append(logits_cond_i)
713
+
714
+ ### 2 - Unconditional forward pass
715
+ logits_uncond, mod_pos = self.forward_enc_dec_maskgit_batched(uncond_dict, target_mod, seed=seed)
716
+
717
+ ### 3 - Conjunction of multiple conditions: l_uncond + sum_i{w_i * (l_cond_i - l_uncond)}
718
+ # See https://arxiv.org/abs/2206.01714
719
+ logits = logits_uncond + torch.stack([w * (logits_cond - logits_uncond) for w, logits_cond in zip(cond_weights, logits_cond_all)]).sum(dim=0)
720
+
721
+ ### 4 - MaskGIT sampling
722
+ top_samples, top_indices, all_samples = self.select_tokens_batched(
723
+ logits, num_select,
724
+ temperature=temperature, top_k=top_k, top_p=top_p,
725
+ return_all_samples=True
726
+ )
727
+
728
+ ### 5 - Update mod dict with newly generated tokens
729
+ # We rely on gather / scatter for batched operations
730
+ top_pos = torch.gather(mod_pos, -1, top_indices) # (B, num_select)
731
+ if write_all_predictions:
732
+ uncond_dict[target_mod]['tensor'] = torch.scatter(uncond_dict[target_mod]['tensor'], -1, mod_pos, all_samples)
733
+ else:
734
+ uncond_dict[target_mod]['tensor'] = torch.scatter(uncond_dict[target_mod]['tensor'], -1, top_pos, top_samples)
735
+ uncond_dict[target_mod]['input_mask'] = torch.scatter(uncond_dict[target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool))
736
+ uncond_dict[target_mod]['target_mask'] = torch.scatter(uncond_dict[target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool))
737
+ # Update conditioning dicts
738
+ for i in range(len(cond_dicts)):
739
+ cond_dicts[i][target_mod]['tensor'] = torch.scatter(cond_dicts[i][target_mod]['tensor'], -1, top_pos, top_samples)
740
+ cond_dicts[i][target_mod]['input_mask'] = torch.scatter(cond_dicts[i][target_mod]['input_mask'], -1, top_pos, torch.zeros_like(top_samples, dtype=torch.bool))
741
+ cond_dicts[i][target_mod]['target_mask'] = torch.scatter(cond_dicts[i][target_mod]['target_mask'], -1, top_pos, torch.ones_like(top_samples, dtype=torch.bool))
742
+
743
+ return uncond_dict, cond_dicts
744
+
745
+ def forward_enc_dec_roar_batched(self, mod_dict, target_mod, num_select, seed=None):
746
+ # Encoder
747
+ encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d)
748
+ for mod, d in mod_dict.items()
749
+ if mod in self.model.encoder_embeddings}
750
+ encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict)
751
+ x = encoder_tokens + encoder_emb
752
+ x = self.model.forward_encoder(x, encoder_mask)
753
+
754
+ # Decoder
755
+ context = self.model.decoder_proj_context(x) + encoder_emb
756
+ decoder_mod_dict = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])}
757
+ decoder_tokens, decoder_emb, decoder_mask, decoder_mod_mask, mod_pos = self.forward_mask_decoder_roar(decoder_mod_dict, target_mod, num_select, seed=seed)
758
+ y = decoder_tokens + decoder_emb
759
+ y = self.model.forward_decoder(y, context, encoder_mask, None)
760
+ B, N, D = y.shape
761
+ logits = self.model.forward_logits(y, decoder_mod_dict, decoder_mod_mask)[target_mod]
762
+ logits = logits.reshape(B, N, -1)
763
+
764
+ return logits, mod_pos
765
+
766
+ def roar_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p, seed=None):
767
+ """ROAR = Random Order Autoregression"""
768
+
769
+ logits, mod_pos = self.forward_enc_dec_roar_batched(mod_dict, target_mod, num_select, seed=seed)
770
+
771
+ # Simple sampling
772
+ samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k=top_k, top_p=top_p)
773
+
774
+ # Update mod dict
775
+ # We rely on scatter for batched operations
776
+ select_pos = mod_pos
777
+ mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, select_pos, samples)
778
+ mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool))
779
+ mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool))
780
+
781
+ return mod_dict
782
+
783
+ def guided_roar_step_batched(self, mod_dict, target_mod, num_select, temperature, top_k, top_p,
784
+ conditioning=[], guidance_scale=1.0, seed=None):
785
+ """ROAR = Random Order Autoregression"""
786
+
787
+ ### 1 - First pass, with conditioning
788
+ logits_cond, _ = self.forward_enc_dec_roar_batched(mod_dict, target_mod, num_select, seed=seed)
789
+
790
+ ### 2 - Second pass, without conditioning
791
+ mod_dict_uncond = copy.deepcopy(mod_dict)
792
+ for mod in conditioning:
793
+ if self.model.modality_info[mod]['type'] in ['seq', 'seq_token']:
794
+ mod_dict_uncond = empty_seq_modality(mod_dict_uncond, mod)
795
+ elif self.model.modality_info[mod]['type'] in ['seq_emb']:
796
+ mod_dict_uncond = empty_seq_emb_modality(mod_dict_uncond, mod)
797
+ else:
798
+ mod_dict_uncond = empty_img_modality(mod_dict_uncond, mod)
799
+
800
+ logits_uncond, mod_pos = self.forward_enc_dec_roar_batched(mod_dict_uncond, target_mod, num_select, seed=seed)
801
+
802
+ ### 3 - Classifier-free guidance
803
+ logits = logits_uncond + (logits_cond - logits_uncond) * guidance_scale
804
+
805
+ ### 4 - Simple sampling
806
+ samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k=top_k, top_p=top_p)
807
+
808
+ ### 5 - Update mod dict
809
+ # We rely on gather / scatter for batched operations
810
+ select_pos = mod_pos
811
+ mod_dict[target_mod]['tensor'] = torch.scatter(mod_dict[target_mod]['tensor'], -1, select_pos, samples)
812
+ mod_dict[target_mod]['input_mask'] = torch.scatter(mod_dict[target_mod]['input_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool))
813
+ mod_dict[target_mod]['target_mask'] = torch.scatter(mod_dict[target_mod]['target_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool))
814
+
815
+ return mod_dict
816
+
817
+ def multi_guided_roar_step_batched(self, uncond_dict, cond_dicts, cond_weights, target_mod,
818
+ num_select, temperature, top_k, top_p, seed=None):
819
+
820
+ ### 1 - Conditional forward passes (one for each guided condition)
821
+ logits_cond_all = []
822
+ for cond_dict in cond_dicts:
823
+ logits_cond_i, _ = self.forward_enc_dec_roar_batched(cond_dict, target_mod, num_select, seed=seed)
824
+ logits_cond_all.append(logits_cond_i)
825
+
826
+ ### 2 - Unconditional forward pass
827
+ logits_uncond, mod_pos = self.forward_enc_dec_roar_batched(uncond_dict, target_mod, num_select, seed=seed)
828
+
829
+ ### 3 - Conjunction of multiple conditions: l_uncond + sum_i{w_i * (l_cond_i - l_uncond)}
830
+ # See https://arxiv.org/abs/2206.01714
831
+ logits = logits_uncond + torch.stack([w * (logits_cond - logits_uncond) for w, logits_cond in zip(cond_weights, logits_cond_all)]).sum(dim=0)
832
+
833
+ ### 4 - Simple sampling
834
+ samples, sampled_probs = self.sample_tokens_batched(logits, temperature, top_k=top_k, top_p=top_p)
835
+
836
+ ### 5 - Update mod dict
837
+ # We rely on gather / scatter for batched operations
838
+ select_pos = mod_pos
839
+ uncond_dict[target_mod]['tensor'] = torch.scatter(uncond_dict[target_mod]['tensor'], -1, select_pos, samples)
840
+ uncond_dict[target_mod]['input_mask'] = torch.scatter(uncond_dict[target_mod]['input_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool))
841
+ uncond_dict[target_mod]['target_mask'] = torch.scatter(uncond_dict[target_mod]['target_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool))
842
+ # Update conditioning dicts
843
+ for i in range(len(cond_dicts)):
844
+ cond_dicts[i][target_mod]['tensor'] = torch.scatter(cond_dicts[i][target_mod]['tensor'], -1, select_pos, samples)
845
+ cond_dicts[i][target_mod]['input_mask'] = torch.scatter(cond_dicts[i][target_mod]['input_mask'], -1, select_pos, torch.ones_like(samples, dtype=torch.bool))
846
+ cond_dicts[i][target_mod]['target_mask'] = torch.scatter(cond_dicts[i][target_mod]['target_mask'], -1, select_pos, torch.zeros_like(samples, dtype=torch.bool))
847
+
848
+ return uncond_dict, cond_dicts
849
+
850
+ def autoregressive_step_batched(self, mod_dict, target_mod, temperature, top_k: Union[float, int], top_p: float,
851
+ use_eos=True, eos_token=None, start_tokens=None, text_tokenizer=None, seed=None):
852
+
853
+ # Encoder
854
+ encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d)
855
+ for mod, d in mod_dict.items()
856
+ if mod in self.model.encoder_embeddings}
857
+ encoder_tokens, encoder_emb, encoder_mask, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict)
858
+ x = encoder_tokens + encoder_emb
859
+ x = self.model.forward_encoder(x, encoder_mask) # B, N, D
860
+
861
+ # Get batch size
862
+ B = x.shape[0]
863
+
864
+ # Decoder
865
+ context = self.model.decoder_proj_context(x) + encoder_emb
866
+ decoder_mod_dict = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])}
867
+ decoder_ids, decoder_emb, decoder_mask, decoder_mod_mask, mod_pos = self.forward_mask_decoder_autoregressive(decoder_mod_dict, target_mod, seed=seed)
868
+ device = decoder_ids.device
869
+ seq_len = self.model.modality_info[target_mod]['max_tokens']
870
+
871
+ if use_eos and eos_token is None:
872
+ # The eos_token is the final sentinel token provided
873
+ eos_token = decoder_ids[0][decoder_mask[0] == 0][-1] # Assumes the EOS token is the same for all
874
+ if use_eos:
875
+ eos_token = eos_token.to(device)
876
+
877
+ # If no start_tokens, just use the beginning of the actual target (i.e., a sentinel token)
878
+ out = decoder_ids[:, :1] if start_tokens is None else start_tokens.to(device)
879
+ # Set decoder_tokens to None, we do not use them for decoding
880
+ decoder_ids = None
881
+
882
+ # If all samples of the batch have eos, return early
883
+ if use_eos and (out == eos_token).any(dim=-1).all():
884
+ return out
885
+
886
+ y_emb = decoder_emb[:, :seq_len]
887
+ seq_len = y_emb.shape[1]
888
+
889
+ # Auto-regressive decoding and sampling
890
+ for i in range(seq_len):
891
+ cur_len = out.shape[1]
892
+ # Convert ids into word embeddings and add corresponding posembs + modemb
893
+ y = self.model.decoder_embeddings[target_mod].token_emb(out) + y_emb[:, :cur_len]
894
+ # Build causal mask
895
+ causal_mask = torch.ones((cur_len, cur_len), dtype=torch.bool, device=y.device).triu(1)
896
+ causal_mask = repeat(causal_mask, "n1 n2 -> b n1 n2", b=B)
897
+
898
+ y = self.model.forward_decoder(y, context, encoder_mask, causal_mask)
899
+ logits = self.model.forward_logits(y, decoder_mod_dict, decoder_mod_mask[:, :cur_len])[target_mod]
900
+ logits = rearrange(logits, "(b n) d -> b n d", b=B, n=cur_len)
901
+ last_logits = logits[:, -1]
902
+
903
+ # Sample token for the newly generated logit
904
+ if np.isclose(temperature, 0, atol=1e-10):
905
+ sample = torch.argmax(last_logits, dim=-1, keepdim=True)
906
+ else:
907
+ filtered_logits = self.top_k_top_p_filtering(last_logits, top_k, top_p)
908
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
909
+ sample = torch.multinomial(probs, 1)
910
+ out = torch.cat((out, sample), dim=-1)
911
+
912
+ if use_eos and (out == eos_token).any(dim=-1).all():
913
+ break
914
+
915
+ mod_dict = self.merge_sequences_batched(mod_dict, out, target_mod, text_tokenizer)
916
+
917
+ return mod_dict
918
+
919
+ def guided_autoregressive_step_batched(self, mod_dict, target_mod, temperature, top_k: Union[float, int], top_p: float,
920
+ use_eos=True, eos_token=None, start_tokens=None, text_tokenizer=None,
921
+ conditioning=[], guidance_scale=1.0, seed=None):
922
+
923
+ ### 1 - Encoder forward pass, with conditioning
924
+
925
+ # Encoder
926
+ encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d)
927
+ for mod, d in mod_dict.items()
928
+ if mod in self.model.encoder_embeddings}
929
+ encoder_tokens, encoder_emb, encoder_mask_cond, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict)
930
+ x = encoder_tokens + encoder_emb
931
+ x = self.model.forward_encoder(x, encoder_mask_cond) # B, N, D
932
+
933
+ # Get batch size
934
+ B = x.shape[0]
935
+
936
+ # Decoder
937
+ context_cond = self.model.decoder_proj_context(x) + encoder_emb
938
+ decoder_mod_dict_cond = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])}
939
+ decoder_ids, decoder_emb, decoder_mask, decoder_mod_mask_cond, mod_pos = self.forward_mask_decoder_autoregressive(decoder_mod_dict_cond, target_mod, seed=seed)
940
+ device = decoder_ids.device
941
+ seq_len = self.model.modality_info[target_mod]['max_tokens']
942
+
943
+
944
+ ### 2 - Encoder forward pass, without conditioning
945
+
946
+ mod_dict_uncond = copy.deepcopy(mod_dict)
947
+ for mod in conditioning:
948
+ if self.model.modality_info[mod]['type'] in ['seq', 'seq_token']:
949
+ mod_dict_uncond = empty_seq_modality(mod_dict_uncond, mod)
950
+ elif self.model.modality_info[mod]['type'] in ['seq_emb']:
951
+ mod_dict_uncond = empty_seq_emb_modality(mod_dict_uncond, mod)
952
+ else:
953
+ mod_dict_uncond = empty_img_modality(mod_dict_uncond, mod)
954
+
955
+ # Encoder
956
+ encoder_mod_dict = {mod: self.model.encoder_embeddings[mod](d)
957
+ for mod, d in mod_dict_uncond.items()
958
+ if mod in self.model.encoder_embeddings}
959
+ encoder_tokens, encoder_emb, encoder_mask_uncond, encoder_mod_mask = self.forward_mask_encoder_generation(encoder_mod_dict)
960
+ x = encoder_tokens + encoder_emb
961
+ x = self.model.forward_encoder(x, encoder_mask_uncond) # B, N, D
962
+
963
+ # Decoder
964
+ context_uncond = self.model.decoder_proj_context(x) + encoder_emb
965
+ decoder_mod_dict_uncond = {target_mod: self.model.decoder_embeddings[target_mod].forward_embed(mod_dict[target_mod])}
966
+ decoder_ids, decoder_emb, decoder_mask, decoder_mod_mask_uncond, mod_pos = self.forward_mask_decoder_autoregressive(decoder_mod_dict_uncond, target_mod, seed=seed)
967
+
968
+
969
+ if use_eos and eos_token is None:
970
+ # The eos_token is the final sentinel token provided
971
+ eos_token = decoder_ids[0][decoder_mask[0] == 0][-1] # Assumes the EOS token is the same for all
972
+ if use_eos:
973
+ eos_token = eos_token.to(device)
974
+
975
+ # If no start_tokens, just use the beginning of the actual target (i.e., a sentinel token)
976
+ out = decoder_ids[:, :1] if start_tokens is None else start_tokens.to(device)
977
+ # Set decoder_tokens to None, we do not use them for decoding
978
+ decoder_ids = None
979
+
980
+ # If all samples of the batch have eos, return early
981
+ if use_eos and (out == eos_token).any(dim=-1).all():
982
+ return out
983
+
984
+ y_emb = decoder_emb[:, :seq_len]
985
+ seq_len = y_emb.shape[1]
986
+
987
+ ### 3 - Auto-regressive decoding and sampling
988
+ for i in range(seq_len):
989
+ cur_len = out.shape[1]
990
+ # Convert ids into word embeddings and add corresponding posembs + modemb
991
+ y = self.model.decoder_embeddings[target_mod].token_emb(out) + y_emb[:, :cur_len]
992
+ # Build causal mask
993
+ causal_mask = torch.ones((cur_len, cur_len), dtype=torch.bool, device=y.device).triu(1)
994
+ causal_mask = repeat(causal_mask, "n1 n2 -> b n1 n2", b=B)
995
+
996
+ ### 3a - Decoder forward pass, with conditioning
997
+ y_cond = self.model.forward_decoder(y, context_cond, encoder_mask_cond, causal_mask)
998
+ logits_cond = self.model.forward_logits(y_cond, decoder_mod_dict_cond, decoder_mod_mask_cond[:, :cur_len])[target_mod]
999
+ logits_cond = rearrange(logits_cond, "(b n) d -> b n d", b=B, n=cur_len)
1000
+ last_logits_cond = logits_cond[:, -1]
1001
+
1002
+ ### 3b - Decoder forward pass, without conditioning
1003
+ y_uncond = self.model.forward_decoder(y, context_uncond, encoder_mask_uncond, causal_mask)
1004
+ logits_uncond = self.model.forward_logits(y_uncond, decoder_mod_dict_uncond, decoder_mod_mask_uncond[:, :cur_len])[target_mod]
1005
+ logits_uncond = rearrange(logits_uncond, "(b n) d -> b n d", b=B, n=cur_len)
1006
+ last_logits_uncond = logits_uncond[:, -1]
1007
+
1008
+ ### 3c - Classifier-free guidance
1009
+ last_logits = last_logits_uncond + (last_logits_cond - last_logits_uncond) * guidance_scale
1010
+
1011
+ # Sample token for the newly generated logit
1012
+ if np.isclose(temperature, 0, atol=1e-10):
1013
+ sample = torch.argmax(last_logits, dim=-1, keepdim=True)
1014
+ else:
1015
+ filtered_logits = self.top_k_top_p_filtering(last_logits, top_k, top_p)
1016
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
1017
+ sample = torch.multinomial(probs, 1)
1018
+ out = torch.cat((out, sample), dim=-1)
1019
+
1020
+ if use_eos and (out == eos_token).any(dim=-1).all():
1021
+ break
1022
+
1023
+ mod_dict = self.merge_sequences_batched(mod_dict, out, target_mod, text_tokenizer)
1024
+
1025
+ return mod_dict
1026
+
1027
+
1028
+ @torch.no_grad()
1029
+ def generate(self, mod_dict, schedule, top_k=0.0, top_p=0.0, text_tokenizer=None, verbose=False, seed=None):
1030
+ """ Generates a sequence of tokens from the input modalities.
1031
+ :param mod_dict: Dictionary of modalities.
1032
+ :param schedule: Schedule of modalities to use.
1033
+ List of dictionaries containing {target_domain, scheme, num_tokens, temperature, cfg_scale, cfg_cond_domains}.
1034
+ :param top_k: top_k > 0: Keep only top k tokens with highest probability (a.k.a. top-k filtering).
1035
+ :param top_p: top_p > 0.0: Keep the top tokens with cumulative probability >= top_p (a.k.a. nucleus filtering).
1036
+ :param text_tokenizer: Text tokenizer.
1037
+ :param verbose: Whether to print progress.
1038
+ :param seed: Random seed.
1039
+ :return: Generated mod dict.
1040
+ """
1041
+
1042
+ # Input embedding -> tokenizes the modalities - Many are placeholder for now
1043
+ mod_dict = copy.deepcopy(mod_dict)
1044
+
1045
+ for step, schedule_step_info in tqdm(enumerate(schedule), disable=not verbose):
1046
+ target_mod = schedule_step_info['target_domain']
1047
+ temp = schedule_step_info['temperature']
1048
+ cfg_scale = schedule_step_info.get('cfg_scale', 1.0)
1049
+ cfg_conditioning = schedule_step_info.get('cfg_cond_domains', [])
1050
+ seed_i = seed + step if seed is not None else None
1051
+
1052
+ if self.model.modality_info[target_mod]['type'] == 'img':
1053
+ scheme = schedule_step_info['scheme']
1054
+ num_select = schedule_step_info['num_tokens']
1055
+
1056
+ if scheme.lower() == 'maskgit':
1057
+ if cfg_scale == 1.0 or len(cfg_conditioning) == 0:
1058
+ mod_dict = self.maskgit_step_batched(
1059
+ mod_dict, target_mod, num_select, temperature=temp,
1060
+ top_k=top_k, top_p=top_p, seed=seed_i
1061
+ )
1062
+ else:
1063
+ mod_dict = self.guided_maskgit_step_batched(
1064
+ mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p,
1065
+ conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i
1066
+ )
1067
+ elif scheme.lower() == 'roar':
1068
+ if cfg_scale == 1.0 or len(cfg_conditioning) == 0:
1069
+ mod_dict = self.roar_step_batched(
1070
+ mod_dict, target_mod, num_select, temperature=temp,
1071
+ top_k=top_k, top_p=top_p, seed=seed_i
1072
+ )
1073
+ else:
1074
+ mod_dict = self.guided_roar_step_batched(
1075
+ mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p,
1076
+ conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i
1077
+ )
1078
+ else:
1079
+ raise ValueError("Invalid sampling scheme")
1080
+ elif self.model.modality_info[target_mod]['type'] in ['seq', 'seq_token']:
1081
+ if cfg_scale == 1.0 or len(cfg_conditioning) == 0:
1082
+ mod_dict = self.autoregressive_step_batched(
1083
+ mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p,
1084
+ text_tokenizer=text_tokenizer, seed=seed_i
1085
+ )
1086
+ else:
1087
+ mod_dict = self.guided_autoregressive_step_batched(
1088
+ mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p,
1089
+ text_tokenizer=text_tokenizer, conditioning=cfg_conditioning,
1090
+ guidance_scale=cfg_scale, seed=seed_i
1091
+ )
1092
+ else:
1093
+ raise ValueError("Invalid schedule")
1094
+
1095
+ return mod_dict
1096
+
1097
+
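A minimal sketch of driving the sampler end to end. fm_model, text_tok, sample_dict and the domain names are placeholders rather than objects defined in this file, and the schedule is written out literally instead of being built with the chaining helper above:

sampler = GenerationSampler(fm_model)                       # fm_model: a trained 4M model (assumed)
schedule = [
    {'target_domain': 'tok_rgb', 'scheme': 'maskgit', 'num_tokens': 8,
     'temperature': 1.0, 'cfg_scale': 2.0, 'cfg_cond_domains': ['caption']}
    for _ in range(25)                                      # one entry per decoding step (counts are illustrative)
]
out_dict = sampler.generate(sample_dict, schedule, top_k=0.0, top_p=0.0,
                            text_tokenizer=text_tok, seed=0)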
1098
+ @torch.no_grad()
1099
+ def generate_iter(self, mod_dict, schedule, top_k=0.0, top_p=0.0, text_tokenizer=None, verbose=False, seed=None):
1100
+ """ Iterator that generates a sequence of tokens from the input modalities step by step.
1101
+ :param mod_dict: Dictionary of modalities.
1102
+ :param schedule: Schedule of modalities to use.
1103
+ List of dictionaries containing {target_domain, scheme, num_tokens, temperature, cfg_scale, cfg_cond_domains}.
1104
+ :param top_k: top_k > 0: Keep only top k tokens with highest probability (a.k.a. top-k filtering).
1105
+ :param top_p: top_p > 0.0: Keep the top tokens with cumulative probability >= top_p (a.k.a. nucleus filtering).
1106
+ :param text_tokenizer: Text tokenizer.
1107
+ :param verbose: Whether to print progress.
1108
+ :param seed: Random seed.
1109
+ :return: Iterator of generated mod dict.
1110
+ """
1111
+
1112
+ # Input embedding -> tokenizes the modalities - Many are placeholder for now
1113
+ mod_dict = copy.deepcopy(mod_dict)
1114
+
1115
+ for step, schedule_step_info in tqdm(enumerate(schedule), disable=not verbose):
1116
+ target_mod = schedule_step_info['target_domain']
1117
+ temp = schedule_step_info['temperature']
1118
+ cfg_scale = schedule_step_info.get('cfg_scale', 1.0)
1119
+ cfg_conditioning = schedule_step_info.get('cfg_cond_domains', [])
1120
+ seed_i = seed + step if seed is not None else None
1121
+
1122
+ if self.model.modality_info[target_mod]['type'] == 'img':
1123
+ scheme = schedule_step_info['scheme']
1124
+ num_select = schedule_step_info['num_tokens']
1125
+
1126
+ if scheme.lower() == 'maskgit':
1127
+ if cfg_scale == 1.0 or len(cfg_conditioning) == 0:
1128
+ mod_dict = self.maskgit_step_batched(
1129
+ mod_dict, target_mod, num_select, temperature=temp,
1130
+ top_k=top_k, top_p=top_p, seed=seed_i
1131
+ )
1132
+ else:
1133
+ mod_dict = self.guided_maskgit_step_batched(
1134
+ mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p,
1135
+ conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i,
1136
+ write_all_predictions=True
1137
+ )
1138
+ elif scheme.lower() == 'roar':
1139
+ if cfg_scale == 1.0 or len(cfg_conditioning) == 0:
1140
+ mod_dict = self.roar_step_batched(
1141
+ mod_dict, target_mod, num_select, temperature=temp,
1142
+ top_k=top_k, top_p=top_p, seed=seed_i
1143
+ )
1144
+ else:
1145
+ mod_dict = self.guided_roar_step_batched(
1146
+ mod_dict, target_mod, num_select, temperature=temp, top_k=top_k, top_p=top_p,
1147
+ conditioning=cfg_conditioning, guidance_scale=cfg_scale, seed=seed_i
1148
+ )
1149
+ else:
1150
+ raise ValueError("Invalid sampling scheme")
1151
+ elif self.model.modality_info[target_mod]['type'] in ['seq', 'seq_token']:
1152
+ if cfg_scale == 1.0 or len(cfg_conditioning) == 0:
1153
+ mod_dict = self.autoregressive_step_batched(
1154
+ mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p,
1155
+ text_tokenizer=text_tokenizer, seed=seed_i
1156
+ )
1157
+ else:
1158
+ mod_dict = self.guided_autoregressive_step_batched(
1159
+ mod_dict, target_mod, temperature=temp, top_k=top_k, top_p=top_p,
1160
+ text_tokenizer=text_tokenizer, conditioning=cfg_conditioning,
1161
+ guidance_scale=cfg_scale, seed=seed_i
1162
+ )
1163
+ else:
1164
+ raise ValueError("Invalid schedule")
1165
+
1166
+ yield mod_dict
1167
+
1168
+ @torch.no_grad()
1169
+ def generate_multi_guided(self, uncond_dict, cond_dicts, schedule, top_k=0.0, top_p=0.0,
1170
+ text_tokenizer=None, verbose=False, seed=None):
1171
+ # Generation function for multiple weighted conditions
1172
+
1173
+ # To detect when a modality has finished generating, we keep track of the current target modality
1174
+ cur_target_mod = schedule[0]['target_domain']
1175
+
1176
+ uncond_dict = copy.deepcopy(uncond_dict)
1177
+ cond_dicts = copy.deepcopy(cond_dicts)
1178
+
1179
+ # Add the to-be-generated modality to the conditional dicts
1180
+ for i in range(len(cond_dicts)):
1181
+ cond_dicts[i][cur_target_mod] = copy.deepcopy(uncond_dict[cur_target_mod])
1182
+
1183
+ for step, schedule_step_info in tqdm(enumerate(schedule), disable=not verbose):
1184
+ target_mod = schedule_step_info['target_domain']
1185
+ temp = schedule_step_info['temperature']
1186
+ num_select = schedule_step_info['num_tokens']
1187
+ cond_weights = schedule_step_info['cfg_scale']
1188
+
1189
+ # Once a modality is fully generated, add it as a new condition
1190
+ if cur_target_mod != target_mod:
1191
+ for i in range(len(cond_dicts)):
1192
+ # Remove the previously generated modality from the conditionings
1193
+ del cond_dicts[i][cur_target_mod]
1194
+ # Add the next modality to be generated to the conditionings
1195
+ cond_dicts[i][target_mod] = copy.deepcopy(uncond_dict[target_mod])
1196
+
1197
+ # Remove the fully generated modality from the unconditional dict inputs
1198
+ uncond_dict[cur_target_mod]['input_mask'][:] = True
1199
+
1200
+ # Add the previously generated modality as an additional condition
1201
+ new_cond = {}
1202
+ new_cond[cur_target_mod] = copy.deepcopy(uncond_dict[cur_target_mod])
1203
+ new_cond[cur_target_mod]['input_mask'][:] = False
1204
+ new_cond[cur_target_mod]['target_mask'][:] = True
1205
+ new_cond[target_mod] = copy.deepcopy(uncond_dict[target_mod])
1206
+ cond_dicts.append(new_cond)
1207
+
1208
+ cur_target_mod = target_mod
1209
+
1210
+ if self.model.modality_info[target_mod]['type'] == 'img':
1211
+ scheme = schedule_step_info['scheme']
1212
+
1213
+ if scheme.lower() == 'maskgit':
1214
+ uncond_dict, cond_dicts = self.multi_guided_maskgit_step_batched(
1215
+ uncond_dict, cond_dicts, cond_weights, target_mod, num_select, temp, top_k, top_p, seed=seed
1216
+ )
1217
+ elif scheme.lower() == 'roar':
1218
+ uncond_dict, cond_dicts = self.multi_guided_roar_step_batched(
1219
+ uncond_dict, cond_dicts, cond_weights, target_mod, num_select, temp, top_k, top_p, seed=seed
1220
+ )
1221
+ else:
1222
+ raise ValueError("Invalid sampling scheme")
1223
+
1224
+ else:
1225
+ raise NotImplementedError("Only image modalities are supported for now")
1226
+
1227
+ return uncond_dict
1228
+
1229
+ @torch.no_grad()
1230
+ def generate_sam_dense(self, mod_dict, schedule, text_tokenizer, batch_size=16,
1231
+ key='sam_instance', top_k=0.0, top_p=0.0, seed=None, verbose=False):
1232
+ # Generation function for dense SAM instance prediction
1233
+
1234
+ device = mod_dict[list(mod_dict.keys())[0]]['tensor'].device
1235
+ mod_dict = copy.deepcopy(mod_dict)
1236
+ # Repeat the input batch to match the batch size
1237
+ expanded_batch = expand_to_batch(copy.deepcopy(mod_dict), batch_size=batch_size)
1238
+
1239
+ # Filter the schedule to only include the key domain
1240
+ schedule = [s for s in schedule if s['target_domain'] == key]
1241
+
1242
+ out_dict = self.generate(
1243
+ expanded_batch, schedule, text_tokenizer=text_tokenizer,
1244
+ verbose=verbose, seed=seed,
1245
+ top_p=top_p, top_k=top_k,
1246
+ )
1247
+
1248
+ # Merge the batch generated sequences into one sequence
1249
+ sentinel_ids = set(get_sentinel_to_id_mapping(text_tokenizer).values())
1250
+ merged_seq = []
1251
+
1252
+ for i in range(batch_size):
1253
+
1254
+ input_seq = out_dict[key]['tensor'][i]
1255
+ input_seq = input_seq[out_dict[key]['input_mask'][i] == 0]
1256
+ input_seq = input_seq.tolist()
1257
+
1258
+ target_seq = out_dict[key]['tensor'][i]
1259
+ target_seq = target_seq[out_dict[key]['target_mask'][i] == 0]
1260
+ target_seq = target_seq.tolist()
1261
+
1262
+ merged_seq.extend(merge_span_masking(input_seq, target_seq, sentinel_ids=sentinel_ids))
1263
+
1264
+ merged_seq = torch.tensor(merged_seq, device=device).unsqueeze(0)
1265
+
1266
+ mod_dict[key] = {
1267
+ 'tensor': merged_seq,
1268
+ 'input_mask': torch.zeros(merged_seq.shape, dtype=torch.bool, device=device),
1269
+ 'target_mask': torch.ones(merged_seq.shape, dtype=torch.bool, device=device),
1270
+ 'decoder_attention_mask': torch.zeros(merged_seq.shape, dtype=torch.bool, device=device),
1271
+ }
1272
+
1273
+ return mod_dict
fourm/models/lora_utils.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List, Set, Optional, Type
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+
20
+ SELF_ATTENTION_MODULES = {'Attention', 'NormAttention'}
21
+ CROSS_ATTENTION_MODULES = {'CrossAttention', 'NormCrossAttention'}
22
+ ATTENTION_MODULES = SELF_ATTENTION_MODULES | CROSS_ATTENTION_MODULES
23
+ MLP_MODULES = {'Mlp', 'GatedMlp', 'SwiGLUFFNFused'} # SwiGLUFFNFused is from DINOv2
24
+ TRANSFORMER_MODULES = ATTENTION_MODULES | MLP_MODULES
25
+
26
+
27
+ def get_LoRA_module_names(id: str) -> Set[str]:
28
+ """ Returns a list of module names that are LoRA-adapted for the given id. """
29
+ id = id.lower()
30
+ if id in ['selfattn', 'selfattention', 'self_attn', 'self_attention']:
31
+ return SELF_ATTENTION_MODULES
32
+ elif id in ['crossattn', 'crossattention', 'cross_attn', 'cross_attention']:
33
+ return CROSS_ATTENTION_MODULES
34
+ elif id in ['attn', 'attention']:
35
+ return ATTENTION_MODULES
36
+ elif id in ['mlp']:
37
+ return MLP_MODULES
38
+ elif id in ['all', 'transformer']:
39
+ return TRANSFORMER_MODULES
40
+ else:
41
+ raise ValueError(f'Unknown LoRA module id {id}.')
42
+
43
+
44
+ class LoRAWrapper(nn.Module):
45
+ """Low-Rank Adaptation Wrapper for linear layers.
46
+ See https://arxiv.org/abs/2106.09685
47
+
48
+ Args:
49
+ linear: nn.Linear layer to wrap
50
+ rank: Rank of adaptation matrix B@A
51
+ scale: x = W_0@x + scale * B@A@x
52
+ num_packed_linear: Set to > 1 when wrapping e.g. packed kv, or qkv attention weights.
53
+ Weights will be initialized as if num_packed_linear = 1, but the LoRA bottleneck will
54
+ be num_packed_linear times larger.
55
+ """
56
+ def __init__(self, linear: nn.Module, rank: int = 4, scale: float = 1.0, num_packed_linear: int = 1):
57
+ super().__init__()
58
+ self.rank = rank
59
+ self.scale = scale
60
+ self.in_features, self.out_features = linear.in_features, linear.out_features
61
+ assert num_packed_linear * rank <= min(self.in_features, self.out_features), \
62
+ f'LoRA rank {num_packed_linear} * {rank} must be less than or equal to {min(self.in_features, self.out_features)}'
63
+
64
+ self.linear = linear
65
+ self.lora_down = nn.Linear(self.in_features, num_packed_linear*rank, bias=False)
66
+ self.lora_up = nn.Linear(num_packed_linear*rank, self.out_features, bias=False)
67
+
68
+ nn.init.normal_(self.lora_down.weight, std=1/rank)
69
+ nn.init.zeros_(self.lora_up.weight)
70
+
71
+ def fuse_LoRA_into_linear(self) -> nn.Linear:
72
+ """ Returns a single nn.Linear layer with the LoRA matrix fused into the original one. """
73
+ fused_linear = nn.Linear(self.in_features, self.out_features, bias=self.linear.bias is not None)
74
+ fused_linear.weight.data = self.linear.weight + self.scale * (self.lora_up.weight @ self.lora_down.weight)
75
+ if self.linear.bias is not None:
76
+ fused_linear.bias.data = self.linear.bias
77
+ return fused_linear
78
+
79
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
80
+ """ LoRA adapted linear layer forward pass. """
81
+ return self.linear(x) + self.lora_up(self.lora_down(x)) * self.scale
82
+
83
+
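Because the adapter is purely additive, fusing is exact: the fused weight is W_0 + scale * (lora_up.weight @ lora_down.weight), as implemented in fuse_LoRA_into_linear above. A quick self-contained equivalence check (dimensions are illustrative):

import torch
import torch.nn as nn

lin = nn.Linear(16, 32)
wrapped = LoRAWrapper(lin, rank=4, scale=0.5)
wrapped.lora_up.weight.data.normal_()   # make the adapter non-trivial; it is zero-initialized by default
x = torch.randn(2, 16)
assert torch.allclose(wrapped(x), wrapped.fuse_LoRA_into_linear()(x), atol=1e-5)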
84
+ def _find_modules(
85
+ model,
86
+ ancestor_class: Optional[Set[str]] = None,
87
+ search_class: List[Type[nn.Module]] = [nn.Linear],
88
+ exclude_children_of: Optional[List[Type[nn.Module]]] = [LoRAWrapper],
89
+ ):
90
+ """
91
+ Find all modules of a certain class (or union of classes) that are direct or
92
+ indirect descendants of other modules of a certain class (or union of classes).
93
+
94
+ Returns all matching modules, along with the parent of those modules and the
95
+ names they are referenced by.
96
+
97
+ Adapted from https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
98
+ """
99
+ # Get the targets we should replace all linears under
100
+ if ancestor_class is not None:
101
+ ancestors = (
102
+ module
103
+ for module in model.modules()
104
+ if module.__class__.__name__ in ancestor_class
105
+ )
106
+ else:
107
+ # Otherwise, naively iterate over all modules.
108
+ ancestors = [module for module in model.modules()]
109
+
110
+ # For each target find every linear_class module that isn't a child of a LoRA layer
111
+ for ancestor in ancestors:
112
+ for fullname, module in ancestor.named_modules():
113
+ if any([isinstance(module, _class) for _class in search_class]):
114
+ # Find the direct parent if this is a descendant, not a child, of target
115
+ *path, name = fullname.split(".")
116
+ parent = ancestor
117
+ while path:
118
+ parent = parent.get_submodule(path.pop(0))
119
+ # Skip this linear if it's a child of a LoRA layer
120
+ if exclude_children_of and any(
121
+ [isinstance(parent, _class) for _class in exclude_children_of]
122
+ ):
123
+ continue
124
+ # Otherwise, yield it
125
+ yield parent, name, module
126
+
127
+
128
+ def inject_trainable_LoRA(
129
+ model: nn.Module,
130
+ rank: int = 4,
131
+ scale: float = 1.0,
132
+ target_replace_modules: Set[str] = ATTENTION_MODULES
133
+ ) -> None:
134
+ """Replaces all linear layers of the specified modules with LoRA-adapted linear layers.
135
+ Modifies the model in-place.
136
+
137
+ Args:
138
+ model: nn.Module to modify
139
+ rank: Rank of adaptation matrix B@A
140
+ scale: x = W_0@x + scale * B@A@x
141
+ target_replace_modules: Set of module names to replace linear layers in.
142
+ """
143
+ for _module, name, _child_module in _find_modules(
144
+ model, target_replace_modules, search_class=[nn.Linear]
145
+ ):
146
+ if sorted(name) == sorted('qkv'):
147
+ num_packed_linear = 3
148
+ elif sorted(name) in [sorted('kv'), sorted('qk'), sorted('qv')]:
149
+ num_packed_linear = 2
150
+ else:
151
+ num_packed_linear = 1
152
+
153
+ _module._modules[name] = LoRAWrapper(_child_module, rank=rank, scale=scale, num_packed_linear=num_packed_linear)
154
+
155
+
156
+ def fuse_LoRA_into_linear(
157
+ model: nn.Module,
158
+ target_replace_modules: Set[str] = ATTENTION_MODULES
159
+ ) -> None:
160
+ """Fuses all LoRA-adapted linear layers back into single linear layers.
161
+ Modifies the model in-place.
162
+
163
+ Args:
164
+ model: nn.Module to modify
165
+ target_replace_modules: Set of module names to fuse LoRA layers in.
166
+ """
167
+ for _module, name, _child_module in _find_modules(
168
+ model, target_replace_modules, search_class=[LoRAWrapper]
169
+ ):
170
+ _module._modules[name] = _module._modules[name].fuse_LoRA_into_linear()
171
+
172
+
173
+ def unfreeze_all_LoRA_layers(model: nn.Module) -> None:
174
+ """ Unfreezes all LoRA-adapted linear layers. """
175
+ for name, param in model.named_parameters():
176
+ if 'lora' in name:
177
+ param.requires_grad = True
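A short usage sketch for the helpers above. build_model is a stand-in for any transformer whose blocks carry the class names listed in ATTENTION_MODULES / MLP_MODULES and contain plain nn.Linear layers:

model = build_model()  # assumed constructor, not part of this file

# 1) Freeze the base model and inject trainable low-rank adapters into all attention linears.
for p in model.parameters():
    p.requires_grad = False
inject_trainable_LoRA(model, rank=8, scale=1.0, target_replace_modules=ATTENTION_MODULES)
unfreeze_all_LoRA_layers(model)

# ... fine-tune: only the lora_down / lora_up weights receive gradients ...

# 2) After training, fold the adapters back into plain nn.Linear layers for inference.
fuse_LoRA_into_linear(model, target_replace_modules=ATTENTION_MODULES)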
fourm/utils/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ from .misc import *
2
+ from .checkpoint import *
3
+ from .timm.cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
4
+ from .data_constants import *
5
+ from .dist import *
6
+ from .logger import *
7
+ from .timm.metrics import AverageMeter, accuracy
8
+ from .timm.mixup import FastCollateMixup, Mixup
9
+ from .timm.model import freeze, get_state_dict, unfreeze, unwrap_model
10
+ from .timm.model_builder import create_model
11
+ from .timm.model_ema import ModelEma, ModelEmaV2
12
+ from .native_scaler import NativeScalerWithGradNormCount
13
+ from .scheduler import cosine_scheduler, constant_scheduler, inverse_sqrt_scheduler
14
+ from .optim_factory import create_optimizer
15
+ from .timm.registry import model_entrypoint, register_model
16
+ from .timm.transforms import *
17
+ from .timm.transforms_factory import create_transform
18
+ from .tokenizer.text_tokenizer import *
19
+ from .s3_utils import *
20
+ from .run_name import *
21
+ from .generation_datasets import *
22
+ from .seeds import *
fourm/utils/checkpoint.py ADDED
@@ -0,0 +1,185 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------
15
+ # Based on the timm code base
16
+ # https://github.com/huggingface/pytorch-image-models
17
+ # --------------------------------------------------------
18
+ import io
19
+ import os
20
+ import ast
21
+ import json
22
+ from pathlib import Path
23
+ from safetensors.torch import load as load_st
24
+
25
+ import torch
26
+
27
+ from .dist import save_on_main, is_main_process
28
+ from .timm.model import get_state_dict
29
+ from .s3_utils import save_on_s3
30
+
31
+
32
+ def _load_checkpoint_for_ema(model_ema, checkpoint):
33
+ """
34
+ Workaround for ModelEma._load_checkpoint to accept an already-loaded object
35
+ """
36
+ mem_file = io.BytesIO()
37
+ torch.save(checkpoint, mem_file)
38
+ mem_file.seek(0)
39
+ model_ema._load_checkpoint(mem_file)
40
+
41
+
42
+ def load_state_dict(model, state_dict, prefix='', ignore_missing=''):
43
+ missing_keys = []
44
+ unexpected_keys = []
45
+ error_msgs = []
46
+ # copy state_dict so _load_from_state_dict can modify it
47
+ metadata = getattr(state_dict, '_metadata', None)
48
+ state_dict = state_dict.copy()
49
+ if metadata is not None:
50
+ state_dict._metadata = metadata
51
+
52
+ def load(module, prefix=''):
53
+ local_metadata = {} if metadata is None else metadata.get(
54
+ prefix[:-1], {})
55
+ module._load_from_state_dict(
56
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
57
+ for name, child in module._modules.items():
58
+ if child is not None:
59
+ load(child, prefix + name + '.')
60
+
61
+ load(model, prefix=prefix)
62
+
63
+ warn_missing_keys = []
64
+ ignore_missing_keys = []
65
+ for key in missing_keys:
66
+ keep_flag = True
67
+ for ignore_key in ignore_missing.split('|'):
68
+ if ignore_key in key:
69
+ keep_flag = False
70
+ break
71
+ if keep_flag:
72
+ warn_missing_keys.append(key)
73
+ else:
74
+ ignore_missing_keys.append(key)
75
+
76
+ missing_keys = warn_missing_keys
77
+
78
+ if len(missing_keys) > 0:
79
+ print("Weights of {} not initialized from pretrained model: {}".format(
80
+ model.__class__.__name__, missing_keys))
81
+ if len(unexpected_keys) > 0:
82
+ print("Weights from pretrained model not used in {}: {}".format(
83
+ model.__class__.__name__, unexpected_keys))
84
+ if len(ignore_missing_keys) > 0:
85
+ print("Ignored weights of {} not initialized from pretrained model: {}".format(
86
+ model.__class__.__name__, ignore_missing_keys))
87
+ if len(error_msgs) > 0:
88
+ print('\n'.join(error_msgs))
89
+
90
+
91
+ def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, loss_balancer=None, model_ema=None, ckpt_name=None, use_s3=False, all_nodes=False):
92
+ output_dir = Path(args.output_dir)
93
+ epoch_name = str(epoch)
94
+ ckpt_name = ckpt_name or epoch_name
95
+
96
+ # Only create the save_dict on the main process, unless all_nodes is set to True
97
+ if is_main_process() or (all_nodes and args.gpu == 0):
98
+ checkpoint_path = os.path.join(output_dir, f'checkpoint-{ckpt_name}.pth')
99
+
100
+ to_save = {
101
+ 'model': model_without_ddp.state_dict(),
102
+ 'epoch': epoch,
103
+ 'args': args,
104
+ 'scaler': loss_scaler.state_dict(),
105
+ }
106
+
107
+ if optimizer is not None:
108
+ to_save['optimizer'] = optimizer.state_dict()
109
+
110
+ if loss_balancer is not None:
111
+ to_save['loss_balancer'] = loss_balancer.state_dict()
112
+
113
+ if model_ema is not None:
114
+ to_save['model_ema'] = get_state_dict(model_ema)
115
+
116
+ save_on_main(to_save, checkpoint_path)
117
+
118
+ if use_s3:
119
+ s3_path = os.path.join(args.s3_save_dir, f'checkpoint-{ckpt_name}.pth')
120
+ save_on_s3(checkpoint_path, s3_path, args.s3_endpoint)
121
+
122
+
123
+ def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
124
+ output_dir = Path(args.output_dir)
125
+ # torch.amp
126
+ if args.auto_resume and len(args.resume) == 0:
127
+ import glob
128
+ all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
129
+ latest_ckpt = -1
130
+ for ckpt in all_checkpoints:
131
+ t = ckpt.split('-')[-1].split('.')[0]
132
+ if t.isdigit():
133
+ latest_ckpt = max(int(t), latest_ckpt)
134
+ if latest_ckpt >= 0:
135
+ args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
136
+ print("Auto resume checkpoint: %s" % args.resume)
137
+
138
+ if args.resume:
139
+ if args.resume.startswith('https'):
140
+ checkpoint = torch.hub.load_state_dict_from_url(
141
+ args.resume, map_location='cpu')
142
+ else:
143
+ checkpoint = torch.load(args.resume, map_location='cpu')
144
+ model_without_ddp.load_state_dict(checkpoint['model'])
145
+ print("Resume checkpoint %s" % args.resume)
146
+
147
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint:
148
+ optimizer.load_state_dict(checkpoint['optimizer'])
149
+ args.start_epoch = checkpoint['epoch'] + 1
150
+
151
+ if 'scaler' in checkpoint:
152
+ loss_scaler.load_state_dict(checkpoint['scaler'])
153
+ print("With optim & sched!")
154
+
155
+ if hasattr(args, 'model_ema') and args.model_ema:
156
+ _load_checkpoint_for_ema(model_ema, {'state_dict_ema': checkpoint['model_ema']})
157
+ print("With EMA!")
158
+
159
+ def parse_metadata(metadata_str):
160
+ metadata = {}
161
+ for k, v in metadata_str.items():
162
+ try:
163
+ v_parsed = ast.literal_eval(v)
164
+ except:
165
+ v_parsed = v
166
+ metadata[k] = v_parsed
167
+ return metadata
168
+
169
+ def load_safetensors(safetensors_path, return_metadata=True):
170
+ with open(safetensors_path, 'rb') as f:
171
+ data = f.read()
172
+
173
+ tensors = load_st(data)
174
+
175
+ if not return_metadata:
176
+ return tensors
177
+
178
+ n_header = data[:8]
179
+ n = int.from_bytes(n_header, "little")
180
+ metadata_bytes = data[8 : 8 + n]
181
+ header = json.loads(metadata_bytes)
182
+ metadata = header.get("__metadata__", {})
183
+ metadata = parse_metadata(metadata)
184
+
185
+ return tensors, metadata
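A quick sketch of how `load_safetensors` pairs with `load_state_dict` above; the checkpoint path is hypothetical, and the stand-in module is only there to make the snippet runnable:

import torch.nn as nn
from fourm.utils.checkpoint import load_safetensors, load_state_dict

tensors, metadata = load_safetensors('checkpoints/4m_model.safetensors')  # hypothetical path
print(metadata)  # the safetensors '__metadata__' header, parsed back into Python literals where possible

model = nn.Linear(8, 8)          # stand-in; in practice, use the model whose keys match the checkpoint
load_state_dict(model, tensors)  # reports missing/unexpected keys instead of raising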
fourm/utils/clip/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .clip import *
2
+ from .model import *
fourm/utils/clip/clip.py ADDED
@@ -0,0 +1,236 @@
1
+ # --------------------------------------------------------
2
+ # Based on the CLIP code base
3
+ # https://github.com/openai/CLIP
4
+ # --------------------------------------------------------
5
+
6
+ import hashlib
7
+ import os
8
+ import urllib
9
+ import warnings
10
+ from typing import Any, Union, List
11
+ from pkg_resources import packaging
12
+
13
+ import torch
14
+ from PIL import Image
15
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
16
+ from tqdm import tqdm
17
+
18
+ from .model import build_model
19
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
20
+
21
+ try:
22
+ from torchvision.transforms import InterpolationMode
23
+ BICUBIC = InterpolationMode.BICUBIC
24
+ except ImportError:
25
+ BICUBIC = Image.BICUBIC
26
+
27
+
28
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
29
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
30
+
31
+
32
+ __all__ = ["available_models", "load", "tokenize"]
33
+ _tokenizer = _Tokenizer()
34
+
35
+ _MODELS = {
36
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
37
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
38
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
39
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
40
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
41
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
42
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
43
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
44
+ }
45
+
46
+
47
+ def _download(url: str, root: str):
48
+ os.makedirs(root, exist_ok=True)
49
+ filename = os.path.basename(url)
50
+
51
+ expected_sha256 = url.split("/")[-2]
52
+ download_target = os.path.join(root, filename)
53
+
54
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
55
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
56
+
57
+ if os.path.isfile(download_target):
58
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
59
+ return download_target
60
+ else:
61
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
62
+
63
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
64
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
65
+ while True:
66
+ buffer = source.read(8192)
67
+ if not buffer:
68
+ break
69
+
70
+ output.write(buffer)
71
+ loop.update(len(buffer))
72
+
73
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
74
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
75
+
76
+ return download_target
77
+
78
+
79
+ def _convert_image_to_rgb(image):
80
+ return image.convert("RGB")
81
+
82
+
83
+ def _transform(n_px):
84
+ return Compose([
85
+ Resize(n_px, interpolation=BICUBIC),
86
+ CenterCrop(n_px),
87
+ _convert_image_to_rgb,
88
+ ToTensor(),
89
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
90
+ ])
91
+
92
+
93
+ def available_models() -> List[str]:
94
+ """Returns the names of available CLIP models"""
95
+ return list(_MODELS.keys())
96
+
97
+
98
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
99
+ """Load a CLIP model
100
+
101
+ Parameters
102
+ ----------
103
+ name : str
104
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
105
+
106
+ device : Union[str, torch.device]
107
+ The device to put the loaded model
108
+
109
+ jit : bool
110
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
111
+
112
+ download_root: str
113
+ path to download the model files; by default, it uses "~/.cache/clip"
114
+
115
+ Returns
116
+ -------
117
+ model : torch.nn.Module
118
+ The CLIP model
119
+
120
+ preprocess : Callable[[PIL.Image], torch.Tensor]
121
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
122
+ """
123
+ if name in _MODELS:
124
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
125
+ elif os.path.isfile(name):
126
+ model_path = name
127
+ else:
128
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
129
+
130
+ try:
131
+ # loading JIT archive
132
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
133
+ state_dict = None
134
+ except RuntimeError:
135
+ # loading saved state dict
136
+ if jit:
137
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
138
+ jit = False
139
+ state_dict = torch.load(model_path, map_location="cpu")
140
+
141
+ if not jit:
142
+ model = build_model(state_dict or model.state_dict()).to(device)
143
+ if str(device) == "cpu":
144
+ model.float()
145
+ return model, _transform(model.visual.input_resolution)
146
+
147
+ # patch the device names
148
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
149
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
150
+
151
+ def patch_device(module):
152
+ try:
153
+ graphs = [module.graph] if hasattr(module, "graph") else []
154
+ except RuntimeError:
155
+ graphs = []
156
+
157
+ if hasattr(module, "forward1"):
158
+ graphs.append(module.forward1.graph)
159
+
160
+ for graph in graphs:
161
+ for node in graph.findAllNodes("prim::Constant"):
162
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
163
+ node.copyAttributes(device_node)
164
+
165
+ model.apply(patch_device)
166
+ patch_device(model.encode_image)
167
+ patch_device(model.encode_text)
168
+
169
+ # patch dtype to float32 on CPU
170
+ if str(device) == "cpu":
171
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
172
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
173
+ float_node = float_input.node()
174
+
175
+ def patch_float(module):
176
+ try:
177
+ graphs = [module.graph] if hasattr(module, "graph") else []
178
+ except RuntimeError:
179
+ graphs = []
180
+
181
+ if hasattr(module, "forward1"):
182
+ graphs.append(module.forward1.graph)
183
+
184
+ for graph in graphs:
185
+ for node in graph.findAllNodes("aten::to"):
186
+ inputs = list(node.inputs())
187
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
188
+ if inputs[i].node()["value"] == 5:
189
+ inputs[i].node().copyAttributes(float_node)
190
+
191
+ model.apply(patch_float)
192
+ patch_float(model.encode_image)
193
+ patch_float(model.encode_text)
194
+
195
+ model.float()
196
+
197
+ return model, _transform(model.input_resolution.item())
198
+
199
+
200
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
201
+ """
202
+ Returns the tokenized representation of given input string(s)
203
+
204
+ Parameters
205
+ ----------
206
+ texts : Union[str, List[str]]
207
+ An input string or a list of input strings to tokenize
208
+
209
+ context_length : int
210
+ The context length to use; all CLIP models use 77 as the context length
211
+
212
+ truncate: bool
213
+ Whether to truncate the text in case its encoding is longer than the context length
214
+
215
+ Returns
216
+ -------
217
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
218
+ """
219
+ if isinstance(texts, str):
220
+ texts = [texts]
221
+
222
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
223
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
224
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
225
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
226
+
227
+ for i, tokens in enumerate(all_tokens):
228
+ if len(tokens) > context_length:
229
+ if truncate:
230
+ tokens = tokens[:context_length]
231
+ tokens[-1] = eot_token
232
+ else:
233
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
234
+ result[i, :len(tokens)] = torch.tensor(tokens)
235
+
236
+ return result
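Usage of this module mirrors the upstream CLIP API; a short sketch (model choice and captions are arbitrary):

import torch
from fourm.utils.clip import load, tokenize  # re-exported by fourm/utils/clip/__init__.py

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = load("ViT-B/16", device=device)  # downloads the checkpoint to ~/.cache/clip by default

with torch.no_grad():
    tokens = tokenize(["a photo of a dog", "a photo of a cat"]).to(device)
    text_features = model.encode_text(tokens)  # shape [2, embed_dim]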
fourm/utils/clip/model.py ADDED
@@ -0,0 +1,504 @@
1
+ # --------------------------------------------------------
2
+ # Based on the CLIP code base
3
+ # https://github.com/openai/CLIP
4
+ # --------------------------------------------------------
5
+
6
+ from collections import OrderedDict
7
+ from typing import Tuple, Union
8
+ import math
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+
14
+
15
+ class Bottleneck(nn.Module):
16
+ expansion = 4
17
+
18
+ def __init__(self, inplanes, planes, stride=1):
19
+ super().__init__()
20
+
21
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
22
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
23
+ self.bn1 = nn.BatchNorm2d(planes)
24
+
25
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
26
+ self.bn2 = nn.BatchNorm2d(planes)
27
+
28
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
29
+
30
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
31
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
32
+
33
+ self.relu = nn.ReLU(inplace=True)
34
+ self.downsample = None
35
+ self.stride = stride
36
+
37
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
38
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
39
+ self.downsample = nn.Sequential(OrderedDict([
40
+ ("-1", nn.AvgPool2d(stride)),
41
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
42
+ ("1", nn.BatchNorm2d(planes * self.expansion))
43
+ ]))
44
+
45
+ def forward(self, x: torch.Tensor):
46
+ identity = x
47
+
48
+ out = self.relu(self.bn1(self.conv1(x)))
49
+ out = self.relu(self.bn2(self.conv2(out)))
50
+ out = self.avgpool(out)
51
+ out = self.bn3(self.conv3(out))
52
+
53
+ if self.downsample is not None:
54
+ identity = self.downsample(x)
55
+
56
+ out += identity
57
+ out = self.relu(out)
58
+ return out
59
+
60
+
61
+ class AttentionPool2d(nn.Module):
62
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
63
+ super().__init__()
64
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
65
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
66
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
67
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
68
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
69
+ self.num_heads = num_heads
70
+
71
+ def forward(self, x, return_all_tokens=False):
72
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
73
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
74
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
75
+ x, _ = F.multi_head_attention_forward(
76
+ query=x, key=x, value=x,
77
+ embed_dim_to_check=x.shape[-1],
78
+ num_heads=self.num_heads,
79
+ q_proj_weight=self.q_proj.weight,
80
+ k_proj_weight=self.k_proj.weight,
81
+ v_proj_weight=self.v_proj.weight,
82
+ in_proj_weight=None,
83
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
84
+ bias_k=None,
85
+ bias_v=None,
86
+ add_zero_attn=False,
87
+ dropout_p=0,
88
+ out_proj_weight=self.c_proj.weight,
89
+ out_proj_bias=self.c_proj.bias,
90
+ use_separate_proj_weight=True,
91
+ training=self.training,
92
+ need_weights=False
93
+ )
94
+ if return_all_tokens:
95
+ return x
96
+ else:
97
+ return x[0]
98
+
99
+
100
+ class ModifiedResNet(nn.Module):
101
+ """
102
+ A ResNet class that is similar to torchvision's but contains the following changes:
103
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
104
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
105
+ - The final pooling layer is a QKV attention instead of an average pool
106
+ """
107
+
108
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
109
+ super().__init__()
110
+ self.output_dim = output_dim
111
+ self.input_resolution = input_resolution
112
+
113
+ # the 3-layer stem
114
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
115
+ self.bn1 = nn.BatchNorm2d(width // 2)
116
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
117
+ self.bn2 = nn.BatchNorm2d(width // 2)
118
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
119
+ self.bn3 = nn.BatchNorm2d(width)
120
+ self.avgpool = nn.AvgPool2d(2)
121
+ self.relu = nn.ReLU(inplace=True)
122
+
123
+ # residual layers
124
+ self._inplanes = width # this is a *mutable* variable used during construction
125
+ self.layer1 = self._make_layer(width, layers[0])
126
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
127
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
128
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
129
+
130
+ embed_dim = width * 32 # the ResNet feature dimension
131
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
132
+
133
+ def _make_layer(self, planes, blocks, stride=1):
134
+ layers = [Bottleneck(self._inplanes, planes, stride)]
135
+
136
+ self._inplanes = planes * Bottleneck.expansion
137
+ for _ in range(1, blocks):
138
+ layers.append(Bottleneck(self._inplanes, planes))
139
+
140
+ return nn.Sequential(*layers)
141
+
142
+ def forward(self, x, return_side_out=False, return_all_tokens=False):
143
+ def stem(x):
144
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
145
+ x = self.relu(bn(conv(x)))
146
+ x = self.avgpool(x)
147
+ return x
148
+ out = []
149
+ x = x.type(self.conv1.weight.dtype)
150
+ x = stem(x)
151
+ x = self.layer1(x)
152
+ if return_side_out:
153
+ out.append(x)
154
+ x = self.layer2(x)
155
+ if return_side_out:
156
+ out.append(x)
157
+ x = self.layer3(x)
158
+ if return_side_out:
159
+ out.append(x)
160
+ x = self.layer4(x)
161
+ if return_side_out:
162
+ out.append(x)
163
+ x = self.attnpool(x, return_all_tokens)
164
+ out.append(x)
165
+ if len(out) == 1:
166
+ return x
167
+ else:
168
+ return out
169
+
170
+
171
+ class LayerNorm(nn.LayerNorm):
172
+ """Subclass torch's LayerNorm to handle fp16."""
173
+
174
+ def forward(self, x: torch.Tensor):
175
+ orig_type = x.dtype
176
+ ret = super().forward(x.type(torch.float32))
177
+ return ret.type(orig_type)
178
+
179
+
180
+ class QuickGELU(nn.Module):
181
+ def forward(self, x: torch.Tensor):
182
+ return x * torch.sigmoid(1.702 * x)
183
+
184
+
185
+ class ResidualAttentionBlock(nn.Module):
186
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
187
+ super().__init__()
188
+
189
+ self.attn = nn.MultiheadAttention(d_model, n_head)
190
+ self.ln_1 = LayerNorm(d_model)
191
+ self.mlp = nn.Sequential(OrderedDict([
192
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
193
+ ("gelu", QuickGELU()),
194
+ ("c_proj", nn.Linear(d_model * 4, d_model))
195
+ ]))
196
+ self.ln_2 = LayerNorm(d_model)
197
+ self.attn_mask = attn_mask
198
+
199
+ def attention(self, x: torch.Tensor):
200
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
201
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
202
+
203
+ def forward(self, x: torch.Tensor):
204
+ x = x + self.attention(self.ln_1(x))
205
+ x = x + self.mlp(self.ln_2(x))
206
+ return x
207
+
208
+
209
+ class Transformer(nn.Module):
210
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
211
+ super().__init__()
212
+ self.width = width
213
+ self.layers = layers
214
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
215
+
216
+ def forward(self, x: torch.Tensor, return_intermediate_out: bool = False):
217
+ if return_intermediate_out:
218
+ output = []
219
+ for block in self.resblocks:
220
+ x = block(x)
221
+ output.append(x)
222
+ return output
223
+
224
+ return self.resblocks(x)
225
+
226
+
227
+ class VisionTransformer(nn.Module):
228
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
229
+ super().__init__()
230
+ self.input_resolution = input_resolution
231
+ self.patch_size = patch_size
232
+ self.output_dim = output_dim
233
+ self.width = width
234
+ self.heads = heads
235
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
236
+
237
+ scale = width ** -0.5
238
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
239
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
240
+ self.ln_pre = LayerNorm(width)
241
+
242
+ self.transformer = Transformer(width, layers, heads)
243
+
244
+ self.ln_post = LayerNorm(width)
245
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
246
+
247
+ def forward(self, x: torch.Tensor, return_all_tokens=False, return_all_final_tokens=False, return_final_tokens_no_cls=False, **kwargs):
248
+
249
+ B, nc, w, h = x.shape
250
+
251
+ x = self.conv1(x) # shape = [*, width, grid, grid]
252
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
253
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
254
+
255
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
256
+
257
+ if x.shape[1] != self.positional_embedding.shape[0]:
258
+ x = x + self.interpolate_pos_encoding(x, w, h).to(x.dtype)
259
+ else:
260
+ x = x + self.positional_embedding.to(x.dtype)
261
+
262
+ x = self.ln_pre(x)
263
+
264
+ x = x.permute(1, 0, 2) # NLD -> LND
265
+ x = self.transformer(x)
266
+ x = x.permute(1, 0, 2) # LND -> NLD
267
+
268
+ if return_all_tokens:
269
+ x = self.ln_post(x)
270
+ return x[:, 1:, :]
271
+
272
+ if return_all_final_tokens:
273
+ return self.ln_post(x) @ self.proj
274
+
275
+ if return_final_tokens_no_cls:
276
+ return self.ln_post(x)[:, 1:, :] @ self.proj
277
+
278
+ x = self.ln_post(x[:, 0, :])
279
+
280
+ if self.proj is not None:
281
+ x = x @ self.proj
282
+
283
+ return x
284
+
285
+ def interpolate_pos_encoding(self, x, w, h):
286
+ npatch = x.shape[1] - 1
287
+ N = self.positional_embedding.shape[0] - 1 # 256 for large
288
+ if npatch == N and w == h:
289
+ return self.positional_embedding
290
+ class_pos_embed = self.positional_embedding[[0]]
291
+ patch_pos_embed = self.positional_embedding[1:]
292
+ dim = x.shape[-1]
293
+ w0 = w // self.patch_size
294
+ h0 = h // self.patch_size
295
+ # we add a small number to avoid floating point error in the interpolation
296
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
297
+ w0, h0 = w0 + 0.1, h0 + 0.1
298
+ patch_pos_embed = nn.functional.interpolate(
299
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
300
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
301
+ mode='bicubic',
302
+ )
303
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
304
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
305
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
306
+
307
+
308
+ class CLIP(nn.Module):
309
+ def __init__(self,
310
+ embed_dim: int, # 512
311
+ # vision
312
+ image_resolution: int, # 224
313
+ vision_layers: Union[Tuple[int, int, int, int], int], # 12
314
+ vision_width: int, # 768
315
+ vision_patch_size: int, # 16
316
+ # text
317
+ context_length: int, # 77
318
+ vocab_size: int, # 49408
319
+ transformer_width: int, # 512
320
+ transformer_heads: int, # 8
321
+ transformer_layers: int # 12
322
+ ):
323
+ super().__init__()
324
+ self.context_length = context_length
325
+
326
+ if isinstance(vision_layers, (tuple, list)):
327
+ vision_heads = vision_width * 32 // 64
328
+ self.visual = ModifiedResNet(
329
+ layers=vision_layers,
330
+ output_dim=embed_dim,
331
+ heads=vision_heads,
332
+ input_resolution=image_resolution,
333
+ width=vision_width
334
+ )
335
+ else:
336
+ vision_heads = vision_width // 64
337
+ self.visual = VisionTransformer(
338
+ input_resolution=image_resolution,
339
+ patch_size=vision_patch_size,
340
+ width=vision_width,
341
+ layers=vision_layers,
342
+ heads=vision_heads,
343
+ output_dim=embed_dim
344
+ )
345
+
346
+ self.transformer = Transformer(
347
+ width=transformer_width,
348
+ layers=transformer_layers,
349
+ heads=transformer_heads,
350
+ attn_mask=self.build_attention_mask()
351
+ )
352
+
353
+ self.vocab_size = vocab_size
354
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
355
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
356
+ self.ln_final = LayerNorm(transformer_width)
357
+
358
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
359
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
360
+
361
+ self.initialize_parameters()
362
+
363
+ def initialize_parameters(self):
364
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
365
+ nn.init.normal_(self.positional_embedding, std=0.01)
366
+
367
+ if isinstance(self.visual, ModifiedResNet):
368
+ if self.visual.attnpool is not None:
369
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
370
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
371
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
372
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
373
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
374
+
375
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
376
+ for name, param in resnet_block.named_parameters():
377
+ if name.endswith("bn3.weight"):
378
+ nn.init.zeros_(param)
379
+
380
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
381
+ attn_std = self.transformer.width ** -0.5
382
+ fc_std = (2 * self.transformer.width) ** -0.5
383
+ for block in self.transformer.resblocks:
384
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
385
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
386
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
387
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
388
+
389
+ if self.text_projection is not None:
390
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
391
+
392
+ def build_attention_mask(self):
393
+ # lazily create causal attention mask, with full attention between the vision tokens
394
+ # pytorch uses additive attention mask; fill with -inf
395
+ mask = torch.empty(self.context_length, self.context_length)
396
+ mask.fill_(float("-inf"))
397
+ mask.triu_(1) # zero out the lower diagonal
398
+ return mask
399
+
400
+ @property
401
+ def dtype(self):
402
+ return self.visual.conv1.weight.dtype
403
+
404
+ def encode_image(self, image, return_side_out=False, return_all_tokens=False, return_all_final_tokens=False, **kwargs):
405
+ return self.visual(image.type(self.dtype), return_all_tokens, return_all_final_tokens, **kwargs)
406
+
407
+ def encode_text(self, text, return_all_tokens=False, return_patch_tokens=False):
408
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
409
+
410
+ x = x + self.positional_embedding.type(self.dtype)
411
+ x = x.permute(1, 0, 2) # NLD -> LND
412
+ x = self.transformer(x)
413
+ x = x.permute(1, 0, 2) # LND -> NLD
414
+ x = self.ln_final(x).type(self.dtype)
415
+
416
+ if return_patch_tokens:
417
+ return x
418
+ # x.shape = [batch_size, n_ctx, transformer.width]
419
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
420
+ if return_all_tokens:
421
+ x = x @ self.text_projection
422
+ else:
423
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
424
+ return x
425
+
426
+ def forward(self, image, text):
427
+ image_features = self.encode_image(image)
428
+ text_features = self.encode_text(text)
429
+
430
+ # normalized features
431
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
432
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
433
+
434
+ # cosine similarity as logits
435
+ logit_scale = self.logit_scale.exp()
436
+ logits_per_image = logit_scale * image_features @ text_features.t()
437
+ logits_per_text = logits_per_image.t()
438
+
439
+ # shape = [global_batch_size, global_batch_size]
440
+ return logits_per_image, logits_per_text
441
+
442
+
443
+ def convert_weights(model: nn.Module):
444
+ """Convert applicable model parameters to fp16"""
445
+
446
+ def _convert_weights_to_fp16(l):
447
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
448
+ l.weight.data = l.weight.data.half()
449
+ if l.bias is not None:
450
+ l.bias.data = l.bias.data.half()
451
+
452
+ if isinstance(l, nn.MultiheadAttention):
453
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
454
+ tensor = getattr(l, attr)
455
+ if tensor is not None:
456
+ tensor.data = tensor.data.half()
457
+
458
+ for name in ["text_projection", "proj"]:
459
+ if hasattr(l, name):
460
+ attr = getattr(l, name)
461
+ if attr is not None:
462
+ attr.data = attr.data.half()
463
+
464
+ model.apply(_convert_weights_to_fp16)
465
+
466
+
467
+ def build_model(state_dict: dict):
468
+ vit = "visual.proj" in state_dict
469
+
470
+ if vit:
471
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
472
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
473
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
474
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
475
+ image_resolution = vision_patch_size * grid_size
476
+ else:
477
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
478
+ vision_layers = tuple(counts)
479
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
480
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
481
+ vision_patch_size = None
482
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
483
+ image_resolution = output_width * 32
484
+
485
+ embed_dim = state_dict["text_projection"].shape[1]
486
+ context_length = state_dict["positional_embedding"].shape[0]
487
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
488
+ transformer_width = state_dict["ln_final.weight"].shape[0]
489
+ transformer_heads = transformer_width // 64
490
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
491
+
492
+ model = CLIP(
493
+ embed_dim,
494
+ image_resolution, vision_layers, vision_width, vision_patch_size,
495
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
496
+ )
497
+
498
+ for key in ["input_resolution", "context_length", "vocab_size"]:
499
+ if key in state_dict:
500
+ del state_dict[key]
501
+
502
+ convert_weights(model)
503
+ model.load_state_dict(state_dict)
504
+ return model.eval()
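To illustrate how `build_model` infers the architecture purely from tensor shapes and key names, here is a self-contained round trip with a randomly initialized, ViT-B/16-sized CLIP; the dimensions are only illustrative:

from fourm.utils.clip.model import CLIP, build_model

reference = CLIP(embed_dim=512, image_resolution=224, vision_layers=12, vision_width=768,
                 vision_patch_size=16, context_length=77, vocab_size=49408,
                 transformer_width=512, transformer_heads=8, transformer_layers=12)

rebuilt = build_model(reference.state_dict())  # hyperparameters recovered from the state dict alone
assert rebuilt.visual.patch_size == 16         # note: build_model also converts the weights to fp16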
fourm/utils/clip/simple_tokenizer.py ADDED
@@ -0,0 +1,137 @@
1
+ # --------------------------------------------------------
2
+ # Based on the CLIP code base
3
+ # https://github.com/openai/CLIP
4
+ # --------------------------------------------------------
5
+
6
+ import gzip
7
+ import html
8
+ import os
9
+ from functools import lru_cache
10
+
11
+ import ftfy
12
+ import regex as re
13
+
14
+
15
+ @lru_cache()
16
+ def default_bpe():
17
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
18
+
19
+
20
+ @lru_cache()
21
+ def bytes_to_unicode():
22
+ """
23
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
24
+ The reversible bpe codes work on unicode strings.
25
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
26
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
27
+ This is a significant percentage of your normal, say, 32K bpe vocab.
28
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
29
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
30
+ """
31
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
32
+ cs = bs[:]
33
+ n = 0
34
+ for b in range(2**8):
35
+ if b not in bs:
36
+ bs.append(b)
37
+ cs.append(2**8+n)
38
+ n += 1
39
+ cs = [chr(n) for n in cs]
40
+ return dict(zip(bs, cs))
41
+
42
+
43
+ def get_pairs(word):
44
+ """Return set of symbol pairs in a word.
45
+ Word is represented as tuple of symbols (symbols being variable-length strings).
46
+ """
47
+ pairs = set()
48
+ prev_char = word[0]
49
+ for char in word[1:]:
50
+ pairs.add((prev_char, char))
51
+ prev_char = char
52
+ return pairs
53
+
54
+
55
+ def basic_clean(text):
56
+ text = ftfy.fix_text(text)
57
+ text = html.unescape(html.unescape(text))
58
+ return text.strip()
59
+
60
+
61
+ def whitespace_clean(text):
62
+ text = re.sub(r'\s+', ' ', text)
63
+ text = text.strip()
64
+ return text
65
+
66
+
67
+ class SimpleTokenizer(object):
68
+ def __init__(self, bpe_path: str = default_bpe()):
69
+ self.byte_encoder = bytes_to_unicode()
70
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
71
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
72
+ merges = merges[1:49152-256-2+1]
73
+ merges = [tuple(merge.split()) for merge in merges]
74
+ vocab = list(bytes_to_unicode().values())
75
+ vocab = vocab + [v+'</w>' for v in vocab]
76
+ for merge in merges:
77
+ vocab.append(''.join(merge))
78
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
79
+ self.encoder = dict(zip(vocab, range(len(vocab))))
80
+ self.decoder = {v: k for k, v in self.encoder.items()}
81
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
82
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
83
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
84
+
85
+ def bpe(self, token):
86
+ if token in self.cache:
87
+ return self.cache[token]
88
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
89
+ pairs = get_pairs(word)
90
+
91
+ if not pairs:
92
+ return token+'</w>'
93
+
94
+ while True:
95
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
96
+ if bigram not in self.bpe_ranks:
97
+ break
98
+ first, second = bigram
99
+ new_word = []
100
+ i = 0
101
+ while i < len(word):
102
+ try:
103
+ j = word.index(first, i)
104
+ new_word.extend(word[i:j])
105
+ i = j
106
+ except:
107
+ new_word.extend(word[i:])
108
+ break
109
+
110
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
111
+ new_word.append(first+second)
112
+ i += 2
113
+ else:
114
+ new_word.append(word[i])
115
+ i += 1
116
+ new_word = tuple(new_word)
117
+ word = new_word
118
+ if len(word) == 1:
119
+ break
120
+ else:
121
+ pairs = get_pairs(word)
122
+ word = ' '.join(word)
123
+ self.cache[token] = word
124
+ return word
125
+
126
+ def encode(self, text):
127
+ bpe_tokens = []
128
+ text = whitespace_clean(basic_clean(text)).lower()
129
+ for token in re.findall(self.pat, text):
130
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
131
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
132
+ return bpe_tokens
133
+
134
+ def decode(self, tokens):
135
+ text = ''.join([self.decoder[token] for token in tokens])
136
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
137
+ return text
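A round-trip sketch with the tokenizer (exact token ids depend on the bundled BPE merges, so they are not quoted here):

from fourm.utils.clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()                # uses the bundled bpe_simple_vocab_16e6.txt.gz by default
ids = tokenizer.encode("A photo of a cat!")  # lower-cased and whitespace-cleaned before BPE encoding
print(ids)
print(tokenizer.decode(ids))                 # '</w>' end-of-word markers become spaces, e.g. 'a photo of a cat ! '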
fourm/utils/data_constants.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ DEFAULT_CROP_PCT = 0.875
15
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
16
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
17
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
18
+ IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
19
+ IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
20
+ IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
21
+
22
+ IMAGENET_SURFACE_NORMAL_MEAN = (0.501, 0.405, 0.137)
23
+ IMAGENET_SURFACE_NORMAL_STD = (0.114, 0.165, 0.081)
24
+
25
+ SEG_IGNORE_INDEX = 255
26
+ SEG_IGNORE_INDEX_V2 = 0
27
+ PAD_MASK_VALUE = 254
28
+ COCO_SEMSEG_NUM_CLASSES = 133 + 1 # One extra class for no-class
29
+ ADE20K_SEMSEG_NUM_CLASSES = 150 + 1 # One extra class for no-class
30
+ HYPERSIM_SEMSEG_NUM_CLASSES = 41
31
+
32
+
33
+ IMAGE_TASKS = {'rgb', 'depth', 'semseg', 'semseg_hypersim', 'semseg_coco', 'semseg_ade20k', 'normal'}
34
+ DETECTION_TASKS = {'det'} # 'det_coco', 'det_lvis'
35
+ TEXT_TASKS = {'caption'}
36
+ VISION_TASKS = IMAGE_TASKS | DETECTION_TASKS
37
+ SEQUENCE_TASKS = DETECTION_TASKS | TEXT_TASKS
38
+
39
+ NYU_MEAN = 2070.7764
40
+ NYU_STD = 777.5723
fourm/utils/dist.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------
15
+ # Based on DETR and MMCV code bases
16
+ # https://github.com/facebookresearch/detr
17
+ # https://github.com/open-mmlab/mmcv
18
+ # --------------------------------------------------------
19
+
20
+ import os
21
+ import pickle
22
+ import shutil
23
+ import sys
24
+ import tempfile
25
+ import datetime
26
+
27
+ import torch
28
+ import torch.distributed as dist
29
+
30
+
31
+ def setup_for_distributed(is_main):
32
+ """
33
+ This function disables printing when not in main process
34
+ """
35
+ import builtins as __builtin__
36
+ builtin_print = __builtin__.print
37
+
38
+ def print(*args, **kwargs):
39
+ force = kwargs.pop('force', False)
40
+ if is_main or force or kwargs.get('file', None) == sys.stderr:
41
+ builtin_print(*args, **kwargs)
42
+
43
+ __builtin__.print = print
44
+
45
+
46
+ def is_dist_avail_and_initialized():
47
+ if not dist.is_available():
48
+ return False
49
+ if not dist.is_initialized():
50
+ return False
51
+ return True
52
+
53
+
54
+ def get_world_size():
55
+ if not is_dist_avail_and_initialized():
56
+ return 1
57
+ return dist.get_world_size()
58
+
59
+
60
+ def get_rank():
61
+ if not is_dist_avail_and_initialized():
62
+ return 0
63
+ return dist.get_rank()
64
+
65
+
66
+ def is_main_process():
67
+ return get_rank() == 0
68
+
69
+
70
+ def save_on_main(*args, **kwargs):
71
+ if is_main_process():
72
+ torch.save(*args, **kwargs)
73
+
74
+ def save_on_all(*args, **kwargs):
75
+ torch.save(*args, **kwargs)
76
+
77
+
78
+ def init_distributed_mode(args):
79
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
80
+ args.rank = int(os.environ["RANK"])
81
+ args.world_size = int(os.environ['WORLD_SIZE'])
82
+ args.gpu = int(os.environ['LOCAL_RANK'])
83
+ else:
84
+ print('Not using distributed mode')
85
+ args.distributed = False
86
+ return
87
+
88
+ args.distributed = True
89
+
90
+ torch.cuda.set_device(args.gpu)
91
+ args.dist_backend = 'nccl'
92
+ print('| distributed init (rank {}): {}, gpu {}'.format(
93
+ args.rank, args.dist_url, args.gpu), flush=True)
94
+ # Set timeout to 1h20 in case some long download of dataset has to happen
95
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
96
+ world_size=args.world_size, rank=args.rank,
97
+ timeout=datetime.timedelta(seconds=4800))
98
+ torch.distributed.barrier()
99
+ if ("print_all" not in args) or (not args.print_all):
100
+ setup_for_distributed(args.rank == 0)
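A sketch of how `init_distributed_mode` is typically driven; the argument parser is a placeholder, and the snippet assumes a launcher such as torchrun that sets RANK, WORLD_SIZE and LOCAL_RANK:

import argparse
from fourm.utils.dist import init_distributed_mode, is_main_process

parser = argparse.ArgumentParser()
parser.add_argument('--dist_url', default='env://')  # rendezvous via environment variables
args = parser.parse_args()

init_distributed_mode(args)  # sets args.rank, args.world_size, args.gpu and args.distributed
if args.distributed and is_main_process():
    print(f"initialized {args.world_size} processes")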
fourm/utils/fsdp_utils.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from pathlib import Path
16
+
17
+ import torch
18
+
19
+ from .dist import save_on_main, is_main_process
20
+ from .s3_utils import save_on_s3
21
+
22
+
23
+ from torch.distributed.fsdp import (
24
+ FullyShardedDataParallel as FSDP,
25
+ FullStateDictConfig,
26
+ StateDictType,
27
+ )
28
+
29
+ from torch.distributed.fsdp.api import FullOptimStateDictConfig
30
+
31
+
32
+
33
+ def save_model_fsdp(args, epoch, model, optimizer, model_ema=None, ckpt_name=None, use_s3=False):
34
+ output_dir = Path(args.output_dir)
35
+ epoch_name = str(epoch)
36
+ ckpt_name = ckpt_name or epoch_name
37
+
38
+
39
+ with FSDP.state_dict_type(model,
40
+ StateDictType.FULL_STATE_DICT,
41
+ FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
42
+ FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
43
+ ):
44
+
45
+ model_state_dict = model.state_dict()
46
+ if optimizer is not None:
47
+ optimizer_state_dict = FSDP.optim_state_dict(model, optimizer)
48
+ else:
49
+ optimizer_state_dict = None
50
+
51
+ # Only create the save_dict on the main process, not needed or recommended to do so on all ranks
52
+ # This make save_on_main() redundant
53
+ if is_main_process():
54
+ checkpoint_path = os.path.join(output_dir, f'checkpoint-{ckpt_name}.pth')
55
+
56
+ to_save = {
57
+ 'model': model_state_dict,
58
+ 'epoch': epoch,
59
+ 'args': args,
60
+ }
61
+
62
+ if optimizer is not None:
63
+ to_save['optimizer'] = optimizer_state_dict
64
+
65
+ if model_ema is not None:
66
+ print("Model EMA is currently not supported for FSDP")
67
+ # to_save['model_ema'] = get_state_dict(model_ema)
68
+
69
+ save_on_main(to_save, checkpoint_path)
70
+
71
+ if use_s3:
72
+ s3_path = os.path.join(args.s3_save_dir, f'checkpoint-{ckpt_name}.pth')
73
+ save_on_s3(checkpoint_path, s3_path, args.s3_endpoint)
74
+
75
+
76
+ def auto_load_model_fsdp(args, model, optimizer, model_ema=None):
77
+ output_dir = Path(args.output_dir)
78
+ if args.auto_resume and len(args.resume) == 0:
79
+ import glob
80
+ all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
81
+ latest_ckpt = -1
82
+ for ckpt in all_checkpoints:
83
+ t = ckpt.split('-')[-1].split('.')[0]
84
+ if t.isdigit():
85
+ latest_ckpt = max(int(t), latest_ckpt)
86
+ if latest_ckpt >= 0:
87
+ args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
88
+ print("Auto resume checkpoint: %s" % args.resume)
89
+
90
+ if args.resume:
91
+ if args.resume.startswith('https'):
92
+ checkpoint = torch.hub.load_state_dict_from_url(
93
+ args.resume, map_location='cpu')
94
+ else:
95
+ checkpoint = torch.load(args.resume, map_location='cpu')
96
+
97
+ with FSDP.state_dict_type(
98
+ model,
99
+ StateDictType.FULL_STATE_DICT,
100
+ FullStateDictConfig(rank0_only=False),
101
+ FullOptimStateDictConfig(rank0_only=False),
102
+ ):
103
+
104
+ model.load_state_dict(checkpoint['model'])
105
+ print("Resume checkpoint %s" % args.resume)
106
+
107
+
108
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint:
109
+ optimizer_state_dict = FSDP.optim_state_dict_to_load(checkpoint['optimizer'], model, optimizer)
110
+ optimizer.load_state_dict(optimizer_state_dict)
111
+ args.start_epoch = checkpoint['epoch'] + 1
112
+
113
+ print("With optim & sched!")
114
+
115
+ if hasattr(args, 'model_ema') and args.model_ema:
116
+ print("Model EMA is currently not supported for FSDP")
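A save/resume sketch for the FSDP helpers; it assumes torch.distributed is already initialized (for example via `init_distributed_mode` above), and the model, optimizer and args namespace below are placeholders rather than this repository's training setup:

import os
from argparse import Namespace

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from fourm.utils.fsdp_utils import save_model_fsdp, auto_load_model_fsdp

args = Namespace(output_dir='output', auto_resume=True, resume='')  # minimal stand-in namespace
os.makedirs(args.output_dir, exist_ok=True)

model = FSDP(nn.Linear(1024, 1024).cuda())  # any FSDP-wrapped module
optimizer = torch.optim.AdamW(model.parameters())

save_model_fsdp(args, epoch=0, model=model, optimizer=optimizer)  # gathers a full state dict on rank 0
auto_load_model_fsdp(args, model, optimizer)                      # later: resumes from the latest checkpoint-*.pth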
fourm/utils/generation.py ADDED
@@ -0,0 +1,99 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+ import math
16
+
17
+
18
+ def sample_to_batch(mod_dict, device, domains):
19
+ mod_dict = {
20
+ modality: {k: v.unsqueeze(0).to(device, non_blocking=True) for k, v in d.items()}
21
+ for modality, d in mod_dict.items() if modality in domains
22
+ }
23
+
24
+ return mod_dict
25
+
26
+
27
+ def unbatch(tensor):
28
+ return tensor.detach().squeeze(0).cpu()
29
+
30
+
31
+ def batch_to_sample(mod_dict, domains):
32
+ mod_dict = {
33
+ modality: {k: unbatch(v) for k, v in d.items()}
34
+ for modality, d in mod_dict.items() if modality in domains
35
+ }
36
+
37
+ return mod_dict
38
+
39
+
40
+ def batch_to_device(mod_dict, device, domains):
41
+ mod_dict = {
42
+ modality: {k: v.to(device, non_blocking=True) for k, v in d.items()}
43
+ for modality, d in mod_dict.items() if modality in domains
44
+ }
45
+
46
+ return mod_dict
47
+
48
+
49
+ def cosine_schedule(num_steps, total_tokens):
50
+ iters = np.arange(num_steps)
51
+ base_value = 1
52
+ final_value = 0
53
+ schedule = np.array(
54
+ [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
55
+ schedule_tokens = [round(total_tokens * i) for i in (schedule[:-1] - schedule[1:])]
56
+ schedule_tokens.append(total_tokens - sum(schedule_tokens))
57
+ return np.array(schedule_tokens)
58
+
59
+
60
+ def linear_schedule(num_steps, total_tokens):
61
+ schedule = np.linspace(0, total_tokens, num_steps + 1, dtype=int)
62
+ schedule_tokens = np.diff(schedule)[::-1]
63
+ schedule_tokens.sort() # Sorts the array in ascending order.
64
+ schedule_tokens = schedule_tokens[::-1] # Reverses the array to descending order.
65
+ return np.trim_zeros(schedule_tokens, 'b') # Trims trailing zeros.
66
+
67
+
68
+ def continue_schedule(schedule, num_current_tokens):
69
+ schedule_cumsum = np.cumsum(schedule)
70
+ keep_mask = schedule_cumsum > num_current_tokens
71
+ diff = schedule_cumsum[keep_mask][0] - num_current_tokens
72
+ new_schedule = schedule[keep_mask]
73
+ new_schedule[0] = diff
74
+ return new_schedule
75
+
76
+
77
+ def decreasing_temp_schedule(max, min, token_schedule):
78
+ schedule_cumsum = np.cumsum(token_schedule) / np.sum(token_schedule)
79
+ temp_schedule = np.array([min + (max - min) * (1 - s) for s in schedule_cumsum])
80
+ return temp_schedule
81
+
82
+
83
+ def onex_temp_schedule(max_t, min_t, token_schedule, power=0.5, min_linspace=1, max_linspace=100):
84
+ """Arbitrary temperature schedule for one over x."""
85
+ x = np.linspace(min_linspace, max_linspace, num=sum(token_schedule))
86
+ y = 1/(x**power)
87
+ y = y - min(y)
88
+ y = y / max(y)
89
+ unscaled_schedule = y
90
+ schedule_cumsum = np.cumsum(token_schedule) / np.sum(token_schedule)
91
+ unscaled_schedule = [(1 - cs) * us for us, cs in zip(unscaled_schedule, schedule_cumsum)]
92
+
93
+ temp_schedule = np.array([min_t + (max_t - min_t) * s for s in unscaled_schedule]).clip(min=1e-9)
94
+ return temp_schedule
95
+
96
+
97
+ def linear_temp_schedule(temp, token_schedule):
98
+ """ Temperature that decays the temperature inversely proportional to the token schedule. """
99
+ return np.concatenate([np.array([temp * 1.0]), (temp * (token_schedule.sum() - token_schedule.cumsum()) / token_schedule.sum())[:-1]]).clip(min=1e-9)
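A quick sketch of how the token and temperature schedules compose during iterative decoding (the step count and token budget are arbitrary):

from fourm.utils.generation import cosine_schedule, linear_temp_schedule

tokens_per_step = cosine_schedule(num_steps=8, total_tokens=196)  # per-step token counts following a cosine curve
assert tokens_per_step.sum() == 196                               # by construction, every target token is scheduled

temps = linear_temp_schedule(temp=1.0, token_schedule=tokens_per_step)
print(tokens_per_step)
print(temps)  # decays roughly in proportion to the remaining token budget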
fourm/utils/generation_datasets/PartiPrompts.tsv ADDED
The diff for this file is too large to render.
fourm/utils/generation_datasets/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .parti_prompts_dataset import *
2
+ from .empty_dataset import *
3
+ from .image_caption_dataset import *
fourm/utils/generation_datasets/empty_dataset.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ from torch.utils.data import Dataset
16
+
17
+ class EmptyDataset(Dataset):
18
+ """Empty dataset"""
19
+
20
+ def __init__(self, dataset_size: int):
21
+ self.dataset_size = dataset_size
22
+
23
+ def __getitem__(self, index):
24
+ return {}
25
+
26
+ def __len__(self):
27
+ return self.dataset_size
fourm/utils/generation_datasets/image_caption_dataset.py ADDED
@@ -0,0 +1,99 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from torch.utils.data import Dataset
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, cast
17
+
18
+ from fourm.data.multimodal_dataset_folder import make_dataset, UNIFIED_EXTENSIONS
19
+ from fourm.data.modality_transforms import get_transform_key, RGBTransform, CaptionTransform, UnifiedDataTransform
20
+
21
+
22
+ class ImageCaptionDataset(Dataset):
23
+ """
24
+ Similar to MultiModalDatasetFolder, but specialized for image-caption datasets.
25
+ """
26
+ def __init__(self,
27
+ root: str,
28
+ augmenter: Optional[Callable] = None,
29
+ modality_paths: Dict[str, str] = None,
30
+ is_valid_file: Optional[Callable[[str], bool]] = None,
31
+ cache=False):
32
+ self.root = root
33
+ self.modality_paths = modality_paths or {}
34
+
35
+ self.modality_transforms = {
36
+ 'rgb': RGBTransform(imagenet_default_mean_and_std=False),
37
+ 'caption': CaptionTransform()
38
+ }
39
+
40
+ self.transform = UnifiedDataTransform(transforms_dict=self.modality_transforms, image_augmenter=augmenter)
41
+
42
+ classes, class_to_idx = self._find_classes(os.path.join(self.root, self.modality_paths.get('caption', 'caption')))
43
+ extensions = UNIFIED_EXTENSIONS if is_valid_file is None else None
44
+
45
+ samples = {
46
+ mod: make_dataset(
47
+ os.path.join(self.root, self.modality_paths.get(mod, mod)),
48
+ class_to_idx,
49
+ extensions,
50
+ is_valid_file,
51
+ cache_path=os.path.join(self.root, 'dataloader_cache', f'{self.modality_paths.get(mod, mod)}.pkl') if cache else None)
52
+ for mod in ['caption', 'rgb']
53
+ }
54
+
55
+ for mod, mod_samples in samples.items():
56
+ if len(mod_samples) == 0:
57
+ msg = "Found 0 files in subfolders of: {}\n".format(os.path.join(self.root, self.modality_paths.get(mod, mod)))
58
+ if extensions is not None:
59
+ msg += "Supported extensions are: {}".format(",".join(extensions))
60
+ raise RuntimeError(msg)
61
+
62
+ self.extensions = extensions
63
+ self.classes = classes
64
+ self.class_to_idx = class_to_idx
65
+ self.samples = samples
66
+
67
+ def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
68
+ """
69
+ Finds the class folders in a dataset.
70
+
71
+ Args:
72
+ dir (string): Root directory path.
73
+
74
+ Returns:
75
+ tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
76
+
77
+ Ensures:
78
+ No class is a subdirectory of another.
79
+ """
80
+ classes = [d.name for d in os.scandir(dir) if d.is_dir()]
81
+ classes.sort()
82
+ class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
83
+ return classes, class_to_idx
84
+
85
+ def __getitem__(self, index):
86
+
87
+ sample_dict = {}
88
+ for mod in ['caption', 'rgb']:
89
+ path, _ = self.samples[mod][index]
90
+ sample = self.modality_transforms[get_transform_key(mod)].load(path)
91
+ sample_dict[mod] = sample
92
+
93
+ if self.transform is not None:
94
+ sample_dict = self.transform(sample_dict)
95
+
96
+ return sample_dict
97
+
98
+ def __len__(self) -> int:
99
+ return len(list(self.samples.values())[0])
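A usage sketch under assumed paths: the root directory, sub-folder names, and my_augmenter below are placeholders, not part of this commit. The dataset expects one directory per modality, each containing the same sub-folders with matching file stems.

    #   /data/mydataset/rgb/shard_00/000001.jpg
    #   /data/mydataset/caption/shard_00/000001.json   (or another supported caption extension)
    dataset = ImageCaptionDataset(
        root='/data/mydataset',
        augmenter=my_augmenter,  # an image augmenter from fourm/data/image_augmenter.py
    )
    sample = dataset[0]          # {'rgb': ..., 'caption': ...} after UnifiedDataTransform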
fourm/utils/generation_datasets/parti_prompts_dataset.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright 2024 EPFL and Apple Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import pandas as pd
16
+ import numpy as np
17
+ from torch.utils.data import Dataset
18
+
19
+
20
+ class PartiPromptsDataset(Dataset):
21
+ """
22
+ Parti Prompts caption dataset.
23
+
24
+ Args:
25
+ text_tokenizer (tokenizers.Tokenizer): The tokenizer to use for encoding the captions.
26
+ max_length (int): The maximum sequence length of the captions.
27
+ parti_prompts_csv (str): The path to the Parti Prompts dataset.
28
+ """
29
+ def __init__(self, text_tokenizer, max_length=128, parti_prompts_csv='fourm/utils/generation_datasets/PartiPrompts.tsv', parti_prompts_t5_embs=None, llm_embedder=None):
30
+ self.text_tokenizer = text_tokenizer
31
+ self.max_length = max_length
32
+ self.parti_prompts = pd.read_csv(parti_prompts_csv, sep='\t')
33
+
34
+ self.pad_id = text_tokenizer.token_to_id("[PAD]")
35
+ self.eos_id = text_tokenizer.token_to_id("[EOS]")
36
+ if parti_prompts_t5_embs is not None:
37
+ # T5 embeddings and their validity masks are stored in a .npz archive, so load both here
38
+ self.t5_embs = np.load(parti_prompts_t5_embs)['emb']
39
+ self.t5_masks = np.load(parti_prompts_t5_embs)['mask_valid']
40
+ self.llm_embedder = None
41
+ elif llm_embedder is not None:
42
+ self.t5_embs = None
43
+ self.llm_embedder = llm_embedder
44
+ else:
45
+ self.t5_embs = None
46
+ self.llm_embedder = None
47
+
48
+ def __getitem__(self, index):
49
+ text = self.parti_prompts.Prompt[index]
50
+ seq_ids = self.text_tokenizer.encode(text).ids + [self.eos_id]
51
+
52
+ tensor = torch.ones(self.max_length, dtype=torch.int) * self.pad_id
53
+ tensor[:len(seq_ids)] = torch.tensor(seq_ids, dtype=torch.int)
54
+
55
+ out = {}
56
+ out['caption'] = {'tensor': tensor}
57
+
58
+ if self.t5_embs is not None:
59
+
60
+ t5_emb = torch.tensor(self.t5_embs[index], dtype=torch.float32)
61
+ t5_emb = pad_or_truncate(t5_emb, self.max_length)
62
+
63
+ t5_mask = torch.tensor(self.t5_masks[index], dtype=torch.bool)
64
+ t5_mask = pad_or_truncate(t5_mask, self.max_length)
65
+
66
+ ascii_tensor = text_to_tensor(text, max_length=self.max_length * 10) # Save ASCII as tensor
67
+
68
+ out['t5_caption'] = {
69
+ 'tensor': t5_emb,
70
+ 'mask_valid': t5_mask,
71
+ 'ascii_tensor': ascii_tensor,
72
+ }
73
+ elif self.llm_embedder is not None:
74
+ t5_emb, _, t5_mask = self.llm_embedder.get_text_embeddings([text])
75
+ t5_emb = pad_or_truncate(t5_emb.squeeze(0), self.max_length)
76
+ t5_mask = pad_or_truncate(t5_mask.bool().squeeze(0), self.max_length)
77
+ ascii_tensor = text_to_tensor(text, max_length=self.max_length * 10) # Save ASCII as tensor
78
+
79
+ out['t5_caption'] = {
80
+ 'tensor': t5_emb,
81
+ 'mask_valid': t5_mask,
82
+ 'ascii_tensor': ascii_tensor,
83
+ }
84
+
85
+ return out
86
+
87
+ def __len__(self):
88
+ return len(self.parti_prompts)
89
+
90
+
91
+ def pad_or_truncate(tensor, fixed_length, padding_value=0):
92
+ current_length = tensor.shape[0]
93
+
94
+ if current_length < fixed_length:
95
+ # Calculate padding sizes for all dimensions, but only pad along dim=0
96
+ padding_sizes = [0] * 2 * len(tensor.shape)
97
+ padding_sizes[-1] = fixed_length - current_length  # F.pad's padding list is ordered last-dim-first, so the final entry pads the end of dim 0
98
+ return torch.nn.functional.pad(tensor, padding_sizes, 'constant', padding_value)
99
+ else:
100
+ return tensor[:fixed_length]
101
+
102
+ def text_to_tensor(text, max_length=None):
103
+ """Converts plaintext to a tensor with optional padding."""
104
+ ascii_values = [ord(c) for c in text]
105
+ if max_length:
106
+ while len(ascii_values) < max_length:
107
+ ascii_values.append(0) # Using 0 as the padding value
108
+ return torch.tensor(ascii_values, dtype=torch.int)
109
+
110
+
111
+ def tensor_to_text(tensor):
112
+ """Converts tensor back to plaintext. Assumes padding with zeros."""
113
+ ascii_values = tensor.tolist()
114
+ return ''.join(chr(val) for val in ascii_values if val != 0)
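A quick round-trip sketch for the helpers above (run in this module's namespace; the string, shapes, and lengths are arbitrary). Padding is applied along dim 0 only.

    t = text_to_tensor("a red bicycle", max_length=16)   # ASCII codes, zero-padded to 16
    assert tensor_to_text(t) == "a red bicycle"          # zeros are stripped on decode

    x = torch.randn(10, 4)
    print(pad_or_truncate(x, 12).shape)  # torch.Size([12, 4]) - padded along dim 0
    print(pad_or_truncate(x, 6).shape)   # torch.Size([6, 4])  - truncated along dim 0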
fourm/utils/hmr2_utils/hmr2/__init__.py ADDED
File without changes
fourm/utils/hmr2_utils/hmr2/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .smpl_wrapper import SMPL
2
+ from .hmr2 import HMR2
fourm/utils/hmr2_utils/hmr2/models/backbones/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # --------------------------------------------------------
2
+ # Based on the ViTPose and 4DHumans code bases
3
+ # https://github.com/ViTAE-Transformer/ViTPose/
4
+ # https://github.com/shubham-goel/4D-Humans
5
+ # --------------------------------------------------------
6
+
7
+ from .vit import vit
8
+
9
+ def create_backbone(cfg):
10
+ if cfg.MODEL.BACKBONE.TYPE == 'vit':
11
+ return vit(cfg)
12
+ else:
13
+ raise NotImplementedError('Backbone type is not implemented')
fourm/utils/hmr2_utils/hmr2/models/backbones/vit.py ADDED
@@ -0,0 +1,353 @@
1
+ # --------------------------------------------------------
2
+ # Based on the ViTPose and 4DHumans code bases
3
+ # https://github.com/ViTAE-Transformer/ViTPose/
4
+ # https://github.com/shubham-goel/4D-Humans
5
+ # --------------------------------------------------------
6
+
7
+ import math
8
+
9
+ import torch
10
+ from functools import partial
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint as checkpoint
14
+
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+
17
+ def vit(cfg):
18
+ return ViT(
19
+ img_size=(256, 192),
20
+ patch_size=16,
21
+ embed_dim=1280,
22
+ depth=32,
23
+ num_heads=16,
24
+ ratio=1,
25
+ use_checkpoint=False,
26
+ mlp_ratio=4,
27
+ qkv_bias=True,
28
+ drop_path_rate=0.55,
29
+ )
30
+
31
+ def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
32
+ """
33
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
34
+ dimension for the original embeddings.
35
+ Args:
36
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
37
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
38
+ h, w (int): target height and width of the image token grid; ori_h, ori_w (int): original height and width of the positional embedding grid.
39
+
40
+ Returns:
41
+ Absolute positional embeddings after processing with shape (1, H, W, C)
42
+ """
43
+ cls_token = None
44
+ B, L, C = abs_pos.shape
45
+ if has_cls_token:
46
+ cls_token = abs_pos[:, 0:1]
47
+ abs_pos = abs_pos[:, 1:]
48
+
49
+ if ori_h != h or ori_w != w:
50
+ new_abs_pos = F.interpolate(
51
+ abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
52
+ size=(h, w),
53
+ mode="bicubic",
54
+ align_corners=False,
55
+ ).permute(0, 2, 3, 1).reshape(B, -1, C)
56
+
57
+ else:
58
+ new_abs_pos = abs_pos
59
+
60
+ if cls_token is not None:
61
+ new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
62
+ return new_abs_pos
63
+
64
+ class DropPath(nn.Module):
65
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
66
+ """
67
+ def __init__(self, drop_prob=None):
68
+ super(DropPath, self).__init__()
69
+ self.drop_prob = drop_prob
70
+
71
+ def forward(self, x):
72
+ return drop_path(x, self.drop_prob, self.training)
73
+
74
+ def extra_repr(self):
75
+ return 'p={}'.format(self.drop_prob)
76
+
77
+ class Mlp(nn.Module):
78
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
79
+ super().__init__()
80
+ out_features = out_features or in_features
81
+ hidden_features = hidden_features or in_features
82
+ self.fc1 = nn.Linear(in_features, hidden_features)
83
+ self.act = act_layer()
84
+ self.fc2 = nn.Linear(hidden_features, out_features)
85
+ self.drop = nn.Dropout(drop)
86
+
87
+ def forward(self, x):
88
+ x = self.fc1(x)
89
+ x = self.act(x)
90
+ x = self.fc2(x)
91
+ x = self.drop(x)
92
+ return x
93
+
94
+ class Attention(nn.Module):
95
+ def __init__(
96
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
97
+ proj_drop=0., attn_head_dim=None,):
98
+ super().__init__()
99
+ self.num_heads = num_heads
100
+ head_dim = dim // num_heads
101
+ self.dim = dim
102
+
103
+ if attn_head_dim is not None:
104
+ head_dim = attn_head_dim
105
+ all_head_dim = head_dim * self.num_heads
106
+
107
+ self.scale = qk_scale or head_dim ** -0.5
108
+
109
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
110
+
111
+ self.attn_drop = nn.Dropout(attn_drop)
112
+ self.proj = nn.Linear(all_head_dim, dim)
113
+ self.proj_drop = nn.Dropout(proj_drop)
114
+
115
+ def forward(self, x):
116
+ B, N, C = x.shape
117
+ qkv = self.qkv(x)
118
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
119
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
120
+
121
+ q = q * self.scale
122
+ attn = (q @ k.transpose(-2, -1))
123
+
124
+ attn = attn.softmax(dim=-1)
125
+ attn = self.attn_drop(attn)
126
+
127
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
128
+ x = self.proj(x)
129
+ x = self.proj_drop(x)
130
+
131
+ return x
132
+
133
+ class Block(nn.Module):
134
+
135
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
136
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
137
+ norm_layer=nn.LayerNorm, attn_head_dim=None
138
+ ):
139
+ super().__init__()
140
+
141
+ self.norm1 = norm_layer(dim)
142
+ self.attn = Attention(
143
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
144
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
145
+ )
146
+
147
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
148
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
149
+ self.norm2 = norm_layer(dim)
150
+ mlp_hidden_dim = int(dim * mlp_ratio)
151
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
152
+
153
+ def forward(self, x):
154
+ x = x + self.drop_path(self.attn(self.norm1(x)))
155
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
156
+ return x
157
+
158
+
159
+ class PatchEmbed(nn.Module):
160
+ """ Image to Patch Embedding
161
+ """
162
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
163
+ super().__init__()
164
+ img_size = to_2tuple(img_size)
165
+ patch_size = to_2tuple(patch_size)
166
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
167
+ self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
168
+ self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
169
+ self.img_size = img_size
170
+ self.patch_size = patch_size
171
+ self.num_patches = num_patches
172
+
173
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
174
+
175
+ def forward(self, x, **kwargs):
176
+ B, C, H, W = x.shape
177
+ x = self.proj(x)
178
+ Hp, Wp = x.shape[2], x.shape[3]
179
+
180
+ x = x.flatten(2).transpose(1, 2)
181
+ return x, (Hp, Wp)
182
+
183
+
184
+ class HybridEmbed(nn.Module):
185
+ """ CNN Feature Map Embedding
186
+ Extract feature map from CNN, flatten, project to embedding dim.
187
+ """
188
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
189
+ super().__init__()
190
+ assert isinstance(backbone, nn.Module)
191
+ img_size = to_2tuple(img_size)
192
+ self.img_size = img_size
193
+ self.backbone = backbone
194
+ if feature_size is None:
195
+ with torch.no_grad():
196
+ training = backbone.training
197
+ if training:
198
+ backbone.eval()
199
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
200
+ feature_size = o.shape[-2:]
201
+ feature_dim = o.shape[1]
202
+ backbone.train(training)
203
+ else:
204
+ feature_size = to_2tuple(feature_size)
205
+ feature_dim = self.backbone.feature_info.channels()[-1]
206
+ self.num_patches = feature_size[0] * feature_size[1]
207
+ self.proj = nn.Linear(feature_dim, embed_dim)
208
+
209
+ def forward(self, x):
210
+ x = self.backbone(x)[-1]
211
+ x = x.flatten(2).transpose(1, 2)
212
+ x = self.proj(x)
213
+ return x
214
+
215
+
216
+ class ViT(nn.Module):
217
+
218
+ def __init__(self,
219
+ img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
220
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
221
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
222
+ frozen_stages=-1, ratio=1, last_norm=True,
223
+ patch_padding='pad', freeze_attn=False, freeze_ffn=False,
224
+ ):
225
+ # Protect mutable default arguments
226
+ super(ViT, self).__init__()
227
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
228
+ self.num_classes = num_classes
229
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
230
+ self.frozen_stages = frozen_stages
231
+ self.use_checkpoint = use_checkpoint
232
+ self.patch_padding = patch_padding
233
+ self.freeze_attn = freeze_attn
234
+ self.freeze_ffn = freeze_ffn
235
+ self.depth = depth
236
+
237
+ if hybrid_backbone is not None:
238
+ self.patch_embed = HybridEmbed(
239
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
240
+ else:
241
+ self.patch_embed = PatchEmbed(
242
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
243
+ num_patches = self.patch_embed.num_patches
244
+
245
+ # since the pretraining model has class token
246
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
247
+
248
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
249
+
250
+ self.blocks = nn.ModuleList([
251
+ Block(
252
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
253
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
254
+ )
255
+ for i in range(depth)])
256
+
257
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
258
+
259
+ if self.pos_embed is not None:
260
+ trunc_normal_(self.pos_embed, std=.02)
261
+
262
+ self._freeze_stages()
263
+
264
+ def _freeze_stages(self):
265
+ """Freeze parameters."""
266
+ if self.frozen_stages >= 0:
267
+ self.patch_embed.eval()
268
+ for param in self.patch_embed.parameters():
269
+ param.requires_grad = False
270
+
271
+ for i in range(1, self.frozen_stages + 1):
272
+ m = self.blocks[i]
273
+ m.eval()
274
+ for param in m.parameters():
275
+ param.requires_grad = False
276
+
277
+ if self.freeze_attn:
278
+ for i in range(0, self.depth):
279
+ m = self.blocks[i]
280
+ m.attn.eval()
281
+ m.norm1.eval()
282
+ for param in m.attn.parameters():
283
+ param.requires_grad = False
284
+ for param in m.norm1.parameters():
285
+ param.requires_grad = False
286
+
287
+ if self.freeze_ffn:
288
+ self.pos_embed.requires_grad = False
289
+ self.patch_embed.eval()
290
+ for param in self.patch_embed.parameters():
291
+ param.requires_grad = False
292
+ for i in range(0, self.depth):
293
+ m = self.blocks[i]
294
+ m.mlp.eval()
295
+ m.norm2.eval()
296
+ for param in m.mlp.parameters():
297
+ param.requires_grad = False
298
+ for param in m.norm2.parameters():
299
+ param.requires_grad = False
300
+
301
+ def init_weights(self):
302
+ """Initialize the weights in backbone.
303
+ Args:
304
+ pretrained (str, optional): Path to pre-trained weights.
305
+ Defaults to None.
306
+ """
307
+ def _init_weights(m):
308
+ if isinstance(m, nn.Linear):
309
+ trunc_normal_(m.weight, std=.02)
310
+ if isinstance(m, nn.Linear) and m.bias is not None:
311
+ nn.init.constant_(m.bias, 0)
312
+ elif isinstance(m, nn.LayerNorm):
313
+ nn.init.constant_(m.bias, 0)
314
+ nn.init.constant_(m.weight, 1.0)
315
+
316
+ self.apply(_init_weights)
317
+
318
+ def get_num_layers(self):
319
+ return len(self.blocks)
320
+
321
+ @torch.jit.ignore
322
+ def no_weight_decay(self):
323
+ return {'pos_embed', 'cls_token'}
324
+
325
+ def forward_features(self, x):
326
+ B, C, H, W = x.shape
327
+ x, (Hp, Wp) = self.patch_embed(x)
328
+
329
+ if self.pos_embed is not None:
330
+ # fit for multiple GPU training
331
+ # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
332
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
333
+
334
+ for blk in self.blocks:
335
+ if self.use_checkpoint:
336
+ x = checkpoint.checkpoint(blk, x)
337
+ else:
338
+ x = blk(x)
339
+
340
+ x = self.last_norm(x)
341
+
342
+ xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
343
+
344
+ return xp
345
+
346
+ def forward(self, x):
347
+ x = self.forward_features(x)
348
+ return x
349
+
350
+ def train(self, mode=True):
351
+ """Convert the model into training mode."""
352
+ super().train(mode)
353
+ self._freeze_stages()
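A shape-check sketch for the backbone as configured by vit() above; note this instantiates a ViT-Huge-sized model, and the random input is only there to illustrate the expected input and output shapes.

    model = vit(cfg=None).eval()   # the cfg argument is not used by this factory
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 256, 192))
    print(feats.shape)  # torch.Size([1, 1280, 16, 12]) - channel-first feature map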
fourm/utils/hmr2_utils/hmr2/models/components/__init__.py ADDED
File without changes
fourm/utils/hmr2_utils/hmr2/models/components/pose_transformer.py ADDED
@@ -0,0 +1,363 @@
1
+ # --------------------------------------------------------
2
+ # Based on the 4DHumans code base
3
+ # https://github.com/shubham-goel/4D-Humans
4
+ # --------------------------------------------------------
5
+
6
+ from inspect import isfunction
7
+ from typing import Callable, Optional
8
+
9
+ import torch
10
+ from einops import rearrange
11
+ from einops.layers.torch import Rearrange
12
+ from torch import nn
13
+
14
+ from .t_cond_mlp import (
15
+ AdaptiveLayerNorm1D,
16
+ FrequencyEmbedder,
17
+ normalization_layer,
18
+ )
19
+ # from .vit import Attention, FeedForward
20
+
21
+
22
+ def exists(val):
23
+ return val is not None
24
+
25
+
26
+ def default(val, d):
27
+ if exists(val):
28
+ return val
29
+ return d() if isfunction(d) else d
30
+
31
+
32
+ class PreNorm(nn.Module):
33
+ def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
34
+ super().__init__()
35
+ self.norm = normalization_layer(norm, dim, norm_cond_dim)
36
+ self.fn = fn
37
+
38
+ def forward(self, x: torch.Tensor, *args, **kwargs):
39
+ if isinstance(self.norm, AdaptiveLayerNorm1D):
40
+ return self.fn(self.norm(x, *args), **kwargs)
41
+ else:
42
+ return self.fn(self.norm(x), **kwargs)
43
+
44
+
45
+ class FeedForward(nn.Module):
46
+ def __init__(self, dim, hidden_dim, dropout=0.0):
47
+ super().__init__()
48
+ self.net = nn.Sequential(
49
+ nn.Linear(dim, hidden_dim),
50
+ nn.GELU(),
51
+ nn.Dropout(dropout),
52
+ nn.Linear(hidden_dim, dim),
53
+ nn.Dropout(dropout),
54
+ )
55
+
56
+ def forward(self, x):
57
+ return self.net(x)
58
+
59
+
60
+ class Attention(nn.Module):
61
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
62
+ super().__init__()
63
+ inner_dim = dim_head * heads
64
+ project_out = not (heads == 1 and dim_head == dim)
65
+
66
+ self.heads = heads
67
+ self.scale = dim_head**-0.5
68
+
69
+ self.attend = nn.Softmax(dim=-1)
70
+ self.dropout = nn.Dropout(dropout)
71
+
72
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
73
+
74
+ self.to_out = (
75
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
76
+ if project_out
77
+ else nn.Identity()
78
+ )
79
+
80
+ def forward(self, x):
81
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
82
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
83
+
84
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
85
+
86
+ attn = self.attend(dots)
87
+ attn = self.dropout(attn)
88
+
89
+ out = torch.matmul(attn, v)
90
+ out = rearrange(out, "b h n d -> b n (h d)")
91
+ return self.to_out(out)
92
+
93
+
94
+ class CrossAttention(nn.Module):
95
+ def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
96
+ super().__init__()
97
+ inner_dim = dim_head * heads
98
+ project_out = not (heads == 1 and dim_head == dim)
99
+
100
+ self.heads = heads
101
+ self.scale = dim_head**-0.5
102
+
103
+ self.attend = nn.Softmax(dim=-1)
104
+ self.dropout = nn.Dropout(dropout)
105
+
106
+ context_dim = default(context_dim, dim)
107
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
108
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
109
+
110
+ self.to_out = (
111
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
112
+ if project_out
113
+ else nn.Identity()
114
+ )
115
+
116
+ def forward(self, x, context=None):
117
+ context = default(context, x)
118
+ k, v = self.to_kv(context).chunk(2, dim=-1)
119
+ q = self.to_q(x)
120
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
121
+
122
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
123
+
124
+ attn = self.attend(dots)
125
+ attn = self.dropout(attn)
126
+
127
+ out = torch.matmul(attn, v)
128
+ out = rearrange(out, "b h n d -> b n (h d)")
129
+ return self.to_out(out)
130
+
131
+
132
+ class Transformer(nn.Module):
133
+ def __init__(
134
+ self,
135
+ dim: int,
136
+ depth: int,
137
+ heads: int,
138
+ dim_head: int,
139
+ mlp_dim: int,
140
+ dropout: float = 0.0,
141
+ norm: str = "layer",
142
+ norm_cond_dim: int = -1,
143
+ ):
144
+ super().__init__()
145
+ self.layers = nn.ModuleList([])
146
+ for _ in range(depth):
147
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
148
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
149
+ self.layers.append(
150
+ nn.ModuleList(
151
+ [
152
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
153
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
154
+ ]
155
+ )
156
+ )
157
+
158
+ def forward(self, x: torch.Tensor, *args):
159
+ for attn, ff in self.layers:
160
+ x = attn(x, *args) + x
161
+ x = ff(x, *args) + x
162
+ return x
163
+
164
+
165
+ class TransformerCrossAttn(nn.Module):
166
+ def __init__(
167
+ self,
168
+ dim: int,
169
+ depth: int,
170
+ heads: int,
171
+ dim_head: int,
172
+ mlp_dim: int,
173
+ dropout: float = 0.0,
174
+ norm: str = "layer",
175
+ norm_cond_dim: int = -1,
176
+ context_dim: Optional[int] = None,
177
+ ):
178
+ super().__init__()
179
+ self.layers = nn.ModuleList([])
180
+ for _ in range(depth):
181
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
182
+ ca = CrossAttention(
183
+ dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
184
+ )
185
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
186
+ self.layers.append(
187
+ nn.ModuleList(
188
+ [
189
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
190
+ PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
191
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
192
+ ]
193
+ )
194
+ )
195
+
196
+ def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
197
+ if context_list is None:
198
+ context_list = [context] * len(self.layers)
199
+ if len(context_list) != len(self.layers):
200
+ raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
201
+
202
+ for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
203
+ x = self_attn(x, *args) + x
204
+ x = cross_attn(x, *args, context=context_list[i]) + x
205
+ x = ff(x, *args) + x
206
+ return x
207
+
208
+
209
+ class DropTokenDropout(nn.Module):
210
+ def __init__(self, p: float = 0.1):
211
+ super().__init__()
212
+ if p < 0 or p > 1:
213
+ raise ValueError(
214
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
215
+ )
216
+ self.p = p
217
+
218
+ def forward(self, x: torch.Tensor):
219
+ # x: (batch_size, seq_len, dim)
220
+ if self.training and self.p > 0:
221
+ zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
222
+ # TODO: permutation idx for each batch using torch.argsort
223
+ if zero_mask.any():
224
+ x = x[:, ~zero_mask, :]
225
+ return x
226
+
227
+
228
+ class ZeroTokenDropout(nn.Module):
229
+ def __init__(self, p: float = 0.1):
230
+ super().__init__()
231
+ if p < 0 or p > 1:
232
+ raise ValueError(
233
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
234
+ )
235
+ self.p = p
236
+
237
+ def forward(self, x: torch.Tensor):
238
+ # x: (batch_size, seq_len, dim)
239
+ if self.training and self.p > 0:
240
+ zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
241
+ # Zero-out the masked tokens
242
+ x[zero_mask, :] = 0
243
+ return x
244
+
245
+
246
+ class TransformerEncoder(nn.Module):
247
+ def __init__(
248
+ self,
249
+ num_tokens: int,
250
+ token_dim: int,
251
+ dim: int,
252
+ depth: int,
253
+ heads: int,
254
+ mlp_dim: int,
255
+ dim_head: int = 64,
256
+ dropout: float = 0.0,
257
+ emb_dropout: float = 0.0,
258
+ emb_dropout_type: str = "drop",
259
+ emb_dropout_loc: str = "token",
260
+ norm: str = "layer",
261
+ norm_cond_dim: int = -1,
262
+ token_pe_numfreq: int = -1,
263
+ ):
264
+ super().__init__()
265
+ if token_pe_numfreq > 0:
266
+ token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
267
+ self.to_token_embedding = nn.Sequential(
268
+ Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
269
+ FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
270
+ Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
271
+ nn.Linear(token_dim_new, dim),
272
+ )
273
+ else:
274
+ self.to_token_embedding = nn.Linear(token_dim, dim)
275
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
276
+ if emb_dropout_type == "drop":
277
+ self.dropout = DropTokenDropout(emb_dropout)
278
+ elif emb_dropout_type == "zero":
279
+ self.dropout = ZeroTokenDropout(emb_dropout)
280
+ else:
281
+ raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
282
+ self.emb_dropout_loc = emb_dropout_loc
283
+
284
+ self.transformer = Transformer(
285
+ dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
286
+ )
287
+
288
+ def forward(self, inp: torch.Tensor, *args, **kwargs):
289
+ x = inp
290
+
291
+ if self.emb_dropout_loc == "input":
292
+ x = self.dropout(x)
293
+ x = self.to_token_embedding(x)
294
+
295
+ if self.emb_dropout_loc == "token":
296
+ x = self.dropout(x)
297
+ b, n, _ = x.shape
298
+ x += self.pos_embedding[:, :n]
299
+
300
+ if self.emb_dropout_loc == "token_afterpos":
301
+ x = self.dropout(x)
302
+ x = self.transformer(x, *args)
303
+ return x
304
+
305
+
306
+ class TransformerDecoder(nn.Module):
307
+ def __init__(
308
+ self,
309
+ num_tokens: int,
310
+ token_dim: int,
311
+ dim: int,
312
+ depth: int,
313
+ heads: int,
314
+ mlp_dim: int,
315
+ dim_head: int = 64,
316
+ dropout: float = 0.0,
317
+ emb_dropout: float = 0.0,
318
+ emb_dropout_type: str = 'drop',
319
+ norm: str = "layer",
320
+ norm_cond_dim: int = -1,
321
+ context_dim: Optional[int] = None,
322
+ skip_token_embedding: bool = False,
323
+ ):
324
+ super().__init__()
325
+ if not skip_token_embedding:
326
+ self.to_token_embedding = nn.Linear(token_dim, dim)
327
+ else:
328
+ self.to_token_embedding = nn.Identity()
329
+ if token_dim != dim:
330
+ raise ValueError(
331
+ f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
332
+ )
333
+
334
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
335
+ if emb_dropout_type == "drop":
336
+ self.dropout = DropTokenDropout(emb_dropout)
337
+ elif emb_dropout_type == "zero":
338
+ self.dropout = ZeroTokenDropout(emb_dropout)
339
+ elif emb_dropout_type == "normal":
340
+ self.dropout = nn.Dropout(emb_dropout)
341
+
342
+ self.transformer = TransformerCrossAttn(
343
+ dim,
344
+ depth,
345
+ heads,
346
+ dim_head,
347
+ mlp_dim,
348
+ dropout,
349
+ norm=norm,
350
+ norm_cond_dim=norm_cond_dim,
351
+ context_dim=context_dim,
352
+ )
353
+
354
+ def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
355
+ x = self.to_token_embedding(inp)
356
+ b, n, _ = x.shape
357
+
358
+ x = self.dropout(x)
359
+ x += self.pos_embedding[:, :n]
360
+
361
+ x = self.transformer(x, *args, context=context, context_list=context_list)
362
+ return x
363
+
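A shape sketch for the cross-attention decoder above; all sizes below are arbitrary and only meant to show how tokens and context are wired through.

    dec = TransformerDecoder(num_tokens=1, token_dim=32, dim=64, depth=2,
                             heads=4, mlp_dim=128, context_dim=48)
    tokens = torch.randn(2, 1, 32)    # (batch, num_tokens, token_dim)
    context = torch.randn(2, 10, 48)  # (batch, context_len, context_dim)
    print(dec(tokens, context=context).shape)  # torch.Size([2, 1, 64])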
fourm/utils/hmr2_utils/hmr2/models/components/t_cond_mlp.py ADDED
@@ -0,0 +1,204 @@
1
+ # --------------------------------------------------------
2
+ # Based on the 4DHumans code base
3
+ # https://github.com/shubham-goel/4D-Humans
4
+ # --------------------------------------------------------
5
+
6
+ import copy
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+
11
+
12
+ class AdaptiveLayerNorm1D(torch.nn.Module):
13
+ def __init__(self, data_dim: int, norm_cond_dim: int):
14
+ super().__init__()
15
+ if data_dim <= 0:
16
+ raise ValueError(f"data_dim must be positive, but got {data_dim}")
17
+ if norm_cond_dim <= 0:
18
+ raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
19
+ self.norm = torch.nn.LayerNorm(
20
+ data_dim
21
+ ) # TODO: Check if elementwise_affine=True is correct
22
+ self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
23
+ torch.nn.init.zeros_(self.linear.weight)
24
+ torch.nn.init.zeros_(self.linear.bias)
25
+
26
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
27
+ # x: (batch, ..., data_dim)
28
+ # t: (batch, norm_cond_dim)
29
+ # return: (batch, data_dim)
30
+ x = self.norm(x)
31
+ alpha, beta = self.linear(t).chunk(2, dim=-1)
32
+
33
+ # Add singleton dimensions to alpha and beta
34
+ if x.dim() > 2:
35
+ alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
36
+ beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
37
+
38
+ return x * (1 + alpha) + beta
39
+
40
+
41
+ class SequentialCond(torch.nn.Sequential):
42
+ def forward(self, input, *args, **kwargs):
43
+ for module in self:
44
+ if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
45
+ # print(f'Passing on args to {module}', [a.shape for a in args])
46
+ input = module(input, *args, **kwargs)
47
+ else:
48
+ # print(f'Skipping passing args to {module}', [a.shape for a in args])
49
+ input = module(input)
50
+ return input
51
+
52
+
53
+ def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
54
+ if norm == "batch":
55
+ return torch.nn.BatchNorm1d(dim)
56
+ elif norm == "layer":
57
+ return torch.nn.LayerNorm(dim)
58
+ elif norm == "ada":
59
+ assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
60
+ return AdaptiveLayerNorm1D(dim, norm_cond_dim)
61
+ elif norm is None:
62
+ return torch.nn.Identity()
63
+ else:
64
+ raise ValueError(f"Unknown norm: {norm}")
65
+
66
+
67
+ def linear_norm_activ_dropout(
68
+ input_dim: int,
69
+ output_dim: int,
70
+ activation: torch.nn.Module = torch.nn.ReLU(),
71
+ bias: bool = True,
72
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
73
+ dropout: float = 0.0,
74
+ norm_cond_dim: int = -1,
75
+ ) -> SequentialCond:
76
+ layers = []
77
+ layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
78
+ if norm is not None:
79
+ layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
80
+ layers.append(copy.deepcopy(activation))
81
+ if dropout > 0.0:
82
+ layers.append(torch.nn.Dropout(dropout))
83
+ return SequentialCond(*layers)
84
+
85
+
86
+ def create_simple_mlp(
87
+ input_dim: int,
88
+ hidden_dims: List[int],
89
+ output_dim: int,
90
+ activation: torch.nn.Module = torch.nn.ReLU(),
91
+ bias: bool = True,
92
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
93
+ dropout: float = 0.0,
94
+ norm_cond_dim: int = -1,
95
+ ) -> SequentialCond:
96
+ layers = []
97
+ prev_dim = input_dim
98
+ for hidden_dim in hidden_dims:
99
+ layers.extend(
100
+ linear_norm_activ_dropout(
101
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
102
+ )
103
+ )
104
+ prev_dim = hidden_dim
105
+ layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
106
+ return SequentialCond(*layers)
107
+
108
+
109
+ class ResidualMLPBlock(torch.nn.Module):
110
+ def __init__(
111
+ self,
112
+ input_dim: int,
113
+ hidden_dim: int,
114
+ num_hidden_layers: int,
115
+ output_dim: int,
116
+ activation: torch.nn.Module = torch.nn.ReLU(),
117
+ bias: bool = True,
118
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
119
+ dropout: float = 0.0,
120
+ norm_cond_dim: int = -1,
121
+ ):
122
+ super().__init__()
123
+ if not (input_dim == output_dim == hidden_dim):
124
+ raise NotImplementedError(
125
+ f"input_dim {input_dim} != output_dim {output_dim} is not implemented"
126
+ )
127
+
128
+ layers = []
129
+ prev_dim = input_dim
130
+ for i in range(num_hidden_layers):
131
+ layers.append(
132
+ linear_norm_activ_dropout(
133
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
134
+ )
135
+ )
136
+ prev_dim = hidden_dim
137
+ self.model = SequentialCond(*layers)
138
+ self.skip = torch.nn.Identity()
139
+
140
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
141
+ return x + self.model(x, *args, **kwargs)
142
+
143
+
144
+ class ResidualMLP(torch.nn.Module):
145
+ def __init__(
146
+ self,
147
+ input_dim: int,
148
+ hidden_dim: int,
149
+ num_hidden_layers: int,
150
+ output_dim: int,
151
+ activation: torch.nn.Module = torch.nn.ReLU(),
152
+ bias: bool = True,
153
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
154
+ dropout: float = 0.0,
155
+ num_blocks: int = 1,
156
+ norm_cond_dim: int = -1,
157
+ ):
158
+ super().__init__()
159
+ self.input_dim = input_dim
160
+ self.model = SequentialCond(
161
+ linear_norm_activ_dropout(
162
+ input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
163
+ ),
164
+ *[
165
+ ResidualMLPBlock(
166
+ hidden_dim,
167
+ hidden_dim,
168
+ num_hidden_layers,
169
+ hidden_dim,
170
+ activation,
171
+ bias,
172
+ norm,
173
+ dropout,
174
+ norm_cond_dim,
175
+ )
176
+ for _ in range(num_blocks)
177
+ ],
178
+ torch.nn.Linear(hidden_dim, output_dim, bias=bias),
179
+ )
180
+
181
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
182
+ return self.model(x, *args, **kwargs)
183
+
184
+
185
+ class FrequencyEmbedder(torch.nn.Module):
186
+ def __init__(self, num_frequencies, max_freq_log2):
187
+ super().__init__()
188
+ frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
189
+ self.register_buffer("frequencies", frequencies)
190
+
191
+ def forward(self, x):
192
+ # x should be of size (N,) or (N, D)
193
+ N = x.size(0)
194
+ if x.dim() == 1: # (N,)
195
+ x = x.unsqueeze(1) # (N, D) where D=1
196
+ x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
197
+ scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
198
+ s = torch.sin(scaled)
199
+ c = torch.cos(scaled)
200
+ embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
201
+ N, -1
202
+ ) # (N, D * 2 * num_frequencies + D)
203
+ return embedded
204
+
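A small sketch of the frequency embedding used for conditioning inputs (sizes chosen arbitrarily):

    emb = FrequencyEmbedder(num_frequencies=4, max_freq_log2=3)  # frequencies 2^0 .. 2^3
    x = torch.rand(5, 2)                                         # (N, D)
    print(emb(x).shape)  # torch.Size([5, 18]) = D * (2 * num_frequencies + 1)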
fourm/utils/hmr2_utils/hmr2/models/heads/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .smpl_head import build_smpl_head
fourm/utils/hmr2_utils/hmr2/models/heads/smpl_head.py ADDED
@@ -0,0 +1,116 @@
1
+ # --------------------------------------------------------
2
+ # Based on the 4DHumans code base
3
+ # https://github.com/shubham-goel/4D-Humans
4
+ # --------------------------------------------------------
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ import einops
11
+
12
+ from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat
13
+ from ..components.pose_transformer import TransformerDecoder
14
+
15
+ def build_smpl_head(cfg):
16
+ smpl_head_type = cfg.MODEL.SMPL_HEAD.get('TYPE', 'hmr')
17
+ if smpl_head_type == 'transformer_decoder':
18
+ return SMPLTransformerDecoderHead(cfg)
19
+ else:
20
+ raise ValueError('Unknown SMPL head type: {}'.format(smpl_head_type))
21
+
22
+ class SMPLTransformerDecoderHead(nn.Module):
23
+ """ Cross-attention based SMPL Transformer decoder
24
+ """
25
+
26
+ def __init__(self, cfg):
27
+ super().__init__()
28
+ self.cfg = cfg
29
+ self.joint_rep_type = cfg.MODEL.SMPL_HEAD.get('JOINT_REP', '6d')
30
+ self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
31
+ npose = self.joint_rep_dim * (cfg.SMPL.NUM_BODY_JOINTS + 1)
32
+ self.npose = npose
33
+ self.input_is_mean_shape = cfg.MODEL.SMPL_HEAD.get('TRANSFORMER_INPUT', 'zero') == 'mean_shape'
34
+ transformer_args = dict(
35
+ num_tokens=1,
36
+ token_dim=(npose + 10 + 3) if self.input_is_mean_shape else 1,
37
+ dim=1024,
38
+ )
39
+ transformer_args = (transformer_args | dict(cfg.MODEL.SMPL_HEAD.TRANSFORMER_DECODER))
40
+ self.transformer = TransformerDecoder(
41
+ **transformer_args
42
+ )
43
+ dim=transformer_args['dim']
44
+ self.decpose = nn.Linear(dim, npose)
45
+ self.decshape = nn.Linear(dim, 10)
46
+ self.deccam = nn.Linear(dim, 3)
47
+
48
+ if cfg.MODEL.SMPL_HEAD.get('INIT_DECODER_XAVIER', False):
49
+ # True by default in MLP. False by default in Transformer
50
+ nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
51
+ nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
52
+ nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)
53
+
54
+ mean_params = np.load(cfg.SMPL.MEAN_PARAMS)
55
+ init_body_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
56
+ init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
57
+ init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
58
+ self.register_buffer('init_body_pose', init_body_pose)
59
+ self.register_buffer('init_betas', init_betas)
60
+ self.register_buffer('init_cam', init_cam)
61
+
62
+ def forward(self, x, **kwargs):
63
+
64
+ batch_size = x.shape[0]
65
+ # vit pretrained backbone is channel-first. Change to token-first
66
+ x = einops.rearrange(x, 'b c h w -> b (h w) c')
67
+
68
+ init_body_pose = self.init_body_pose.expand(batch_size, -1)
69
+ init_betas = self.init_betas.expand(batch_size, -1)
70
+ init_cam = self.init_cam.expand(batch_size, -1)
71
+
72
+ # TODO: Convert init_body_pose to aa rep if needed
73
+ if self.joint_rep_type == 'aa':
74
+ raise NotImplementedError
75
+
76
+ pred_body_pose = init_body_pose
77
+ pred_betas = init_betas
78
+ pred_cam = init_cam
79
+ pred_body_pose_list = []
80
+ pred_betas_list = []
81
+ pred_cam_list = []
82
+ for i in range(self.cfg.MODEL.SMPL_HEAD.get('IEF_ITERS', 1)):
83
+ # Input token to the transformer: the current (pose, betas, cam) estimate if input_is_mean_shape, else a zero token
84
+ if self.input_is_mean_shape:
85
+ token = torch.cat([pred_body_pose, pred_betas, pred_cam], dim=1)[:,None,:]
86
+ else:
87
+ token = torch.zeros(batch_size, 1, 1).to(x.device)
88
+
89
+ # Pass through transformer
90
+ token_out = self.transformer(token, context=x)
91
+ token_out = token_out.squeeze(1) # (B, C)
92
+
93
+ # Readout from token_out
94
+ pred_body_pose = self.decpose(token_out) + pred_body_pose
95
+ pred_betas = self.decshape(token_out) + pred_betas
96
+ pred_cam = self.deccam(token_out) + pred_cam
97
+ pred_body_pose_list.append(pred_body_pose)
98
+ pred_betas_list.append(pred_betas)
99
+ pred_cam_list.append(pred_cam)
100
+
101
+ # Convert self.joint_rep_type -> rotmat
102
+ joint_conversion_fn = {
103
+ '6d': rot6d_to_rotmat,
104
+ 'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())
105
+ }[self.joint_rep_type]
106
+
107
+ pred_smpl_params_list = {}
108
+ pred_smpl_params_list['body_pose'] = torch.cat([joint_conversion_fn(pbp).view(batch_size, -1, 3, 3)[:, 1:, :, :] for pbp in pred_body_pose_list], dim=0)
109
+ pred_smpl_params_list['betas'] = torch.cat(pred_betas_list, dim=0)
110
+ pred_smpl_params_list['cam'] = torch.cat(pred_cam_list, dim=0)
111
+ pred_body_pose = joint_conversion_fn(pred_body_pose).view(batch_size, self.cfg.SMPL.NUM_BODY_JOINTS+1, 3, 3)
112
+
113
+ pred_smpl_params = {'global_orient': pred_body_pose[:, [0]],
114
+ 'body_pose': pred_body_pose[:, 1:],
115
+ 'betas': pred_betas}
116
+ return pred_smpl_params, pred_cam, pred_smpl_params_list
fourm/utils/hmr2_utils/hmr2/models/hmr2.py ADDED
@@ -0,0 +1,117 @@
1
+ # --------------------------------------------------------
2
+ # Based on the 4DHumans code base
3
+ # https://github.com/shubham-goel/4D-Humans
4
+ # --------------------------------------------------------
5
+
6
+ import torch
7
+ from typing import Any, Dict, Mapping, Tuple
8
+
9
+ from yacs.config import CfgNode
10
+
11
+ from ..utils import SkeletonRenderer, MeshRenderer
12
+ from ..utils.geometry import perspective_projection
13
+ from .backbones import create_backbone
14
+ from .heads import build_smpl_head
15
+ from . import SMPL
16
+
17
+ class HMR2(torch.nn.Module):
18
+
19
+ def __init__(self, cfg: CfgNode, init_renderer: bool = True):
20
+ """
21
+ Setup HMR2 model
22
+ Args:
23
+ cfg (CfgNode): Config file as a yacs CfgNode
24
+ """
25
+ super().__init__()
26
+
27
+ # Save hyperparameters
28
+ if hasattr(self, 'save_hyperparameters'): self.save_hyperparameters(logger=False, ignore=['init_renderer'])  # Lightning-only; no-op for a plain nn.Module
29
+
30
+ self.cfg = cfg
31
+ # Create backbone feature extractor
32
+ self.backbone = create_backbone(cfg)
33
+ if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
34
+ self.backbone.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict'])
35
+
36
+ # Create SMPL head
37
+ self.smpl_head = build_smpl_head(cfg)
38
+
39
+ # Instantiate SMPL model
40
+ smpl_cfg = {k.lower(): v for k,v in dict(cfg.SMPL).items()}
41
+ self.smpl = SMPL(**smpl_cfg)
42
+
43
+ # Buffer that shows whether we need to initialize ActNorm layers
44
+ self.register_buffer('initialized', torch.tensor(False))
45
+ # Setup renderer for visualization
46
+ if init_renderer:
47
+ self.renderer = SkeletonRenderer(self.cfg)
48
+ self.mesh_renderer = MeshRenderer(self.cfg, faces=self.smpl.faces)
49
+ else:
50
+ self.renderer = None
51
+ self.mesh_renderer = None
52
+
53
+ # Disable automatic optimization since we use adversarial training
54
+ self.automatic_optimization = False
55
+
56
+ def forward_step(self, batch: Dict, train: bool = False) -> Dict:
57
+ """
58
+ Run a forward step of the network
59
+ Args:
60
+ batch (Dict): Dictionary containing batch data
61
+ train (bool): Flag indicating whether it is training or validation mode
62
+ Returns:
63
+ Dict: Dictionary containing the regression output
64
+ """
65
+
66
+ # Use RGB image as input
67
+ x = batch['img']
68
+ batch_size = x.shape[0]
69
+
70
+ # Compute conditioning features using the backbone
71
+ # if using ViT backbone, we need to use a different aspect ratio
72
+ conditioning_feats = self.backbone(x[:,:,:,32:-32])
73
+
74
+ pred_smpl_params, pred_cam, _ = self.smpl_head(conditioning_feats)
75
+
76
+ # Store useful regression outputs to the output dict
77
+ output = {}
78
+ output['pred_cam'] = pred_cam
79
+ output['pred_smpl_params'] = {k: v.clone() for k,v in pred_smpl_params.items()}
80
+
81
+ # Compute camera translation
82
+ device = pred_smpl_params['body_pose'].device
83
+ dtype = pred_smpl_params['body_pose'].dtype
84
+ focal_length = self.cfg.EXTRA.FOCAL_LENGTH * torch.ones(batch_size, 2, device=device, dtype=dtype)
85
+ pred_cam_t = torch.stack([pred_cam[:, 1],
86
+ pred_cam[:, 2],
87
+ 2*focal_length[:, 0]/(self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] +1e-9)],dim=-1)
88
+ output['pred_cam_t'] = pred_cam_t
89
+ output['focal_length'] = focal_length
90
+
91
+ # Compute model vertices, joints and the projected joints
92
+ pred_smpl_params['global_orient'] = pred_smpl_params['global_orient'].reshape(batch_size, -1, 3, 3)
93
+ pred_smpl_params['body_pose'] = pred_smpl_params['body_pose'].reshape(batch_size, -1, 3, 3)
94
+ pred_smpl_params['betas'] = pred_smpl_params['betas'].reshape(batch_size, -1)
95
+ smpl_output = self.smpl(**{k: v.float() for k,v in pred_smpl_params.items()}, pose2rot=False)
96
+ pred_keypoints_3d = smpl_output.joints
97
+ pred_vertices = smpl_output.vertices
98
+ output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
99
+ output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)
100
+ pred_cam_t = pred_cam_t.reshape(-1, 3)
101
+ focal_length = focal_length.reshape(-1, 2)
102
+ pred_keypoints_2d = perspective_projection(pred_keypoints_3d,
103
+ translation=pred_cam_t,
104
+ focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE)
105
+
106
+ output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)
107
+ return output
108
+
109
+ def forward(self, batch: Dict) -> Dict:
110
+ """
111
+ Run a forward step of the network in val mode
112
+ Args:
113
+ batch (Dict): Dictionary containing batch data
114
+ Returns:
115
+ Dict: Dictionary containing the regression output
116
+ """
117
+ return self.forward_step(batch, train=False)
fourm/utils/hmr2_utils/hmr2/models/smpl_wrapper.py ADDED
@@ -0,0 +1,47 @@
1
+ # --------------------------------------------------------
2
+ # Based on the 4DHumans and ProHMR code bases
3
+ # https://github.com/shubham-goel/4D-Humans
4
+ # https://github.com/nkolot/ProHMR
5
+ # --------------------------------------------------------
6
+
7
+ import torch
8
+ import numpy as np
9
+ import pickle
10
+ from typing import Optional
11
+ import smplx
12
+ from smplx.lbs import vertices2joints
13
+ from smplx.utils import SMPLOutput
14
+
15
+
16
+ class SMPL(smplx.SMPLLayer):
17
+ def __init__(self, *args, joint_regressor_extra: Optional[str] = None, update_hips: bool = False, **kwargs):
18
+ """
19
+ Extension of the official SMPL implementation to support more joints.
20
+ Args:
21
+ Same as SMPLLayer.
22
+ joint_regressor_extra (str): Path to extra joint regressor.
23
+ """
24
+ super(SMPL, self).__init__(*args, **kwargs)
25
+ smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4,
26
+ 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
27
+
28
+ if joint_regressor_extra is not None:
29
+ self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
30
+ self.register_buffer('joint_map', torch.tensor(smpl_to_openpose, dtype=torch.long))
31
+ self.update_hips = update_hips
32
+
33
+ def forward(self, *args, **kwargs) -> SMPLOutput:
34
+ """
35
+ Run forward pass. Same as SMPL and also append an extra set of joints if joint_regressor_extra is specified.
36
+ """
37
+ smpl_output = super(SMPL, self).forward(*args, **kwargs)
38
+ joints = smpl_output.joints[:, self.joint_map, :]
39
+ if self.update_hips:
40
+ joints[:,[9,12]] = joints[:,[9,12]] + \
41
+ 0.25*(joints[:,[9,12]]-joints[:,[12,9]]) + \
42
+ 0.5*(joints[:,[8]] - 0.5*(joints[:,[9,12]] + joints[:,[12,9]]))
43
+ if hasattr(self, 'joint_regressor_extra'):
44
+ extra_joints = vertices2joints(self.joint_regressor_extra, smpl_output.vertices)
45
+ joints = torch.cat([joints, extra_joints], dim=1)
46
+ smpl_output.joints = joints
47
+ return smpl_output
fourm/utils/hmr2_utils/hmr2/utils/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # --------------------------------------------------------
2
+ # Based on the 4DHumans and ProHMR code bases
3
+ # https://github.com/shubham-goel/4D-Humans
4
+ # https://github.com/nkolot/ProHMR
5
+ # --------------------------------------------------------
6
+
7
+ import torch
8
+ from typing import Any
9
+
10
+ from .renderer import Renderer
11
+ from .mesh_renderer import MeshRenderer
12
+ from .skeleton_renderer import SkeletonRenderer
13
+ # from .pose_utils import eval_pose, Evaluator
14
+
15
+ def recursive_to(x: Any, target: torch.device):
16
+ """
17
+ Recursively transfer a batch of data to the target device
18
+ Args:
19
+ x (Any): Batch of data.
20
+ target (torch.device): Target device.
21
+ Returns:
22
+ Batch of data where all tensors are transferred to the target device.
23
+ """
24
+ if isinstance(x, dict):
25
+ return {k: recursive_to(v, target) for k, v in x.items()}
26
+ elif isinstance(x, torch.Tensor):
27
+ return x.to(target)
28
+ elif isinstance(x, list):
29
+ return [recursive_to(i, target) for i in x]
30
+ else:
31
+ return x
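A minimal sketch: recursive_to walks nested dicts and lists, moves tensors, and passes every other leaf through unchanged.

    batch = {'img': torch.zeros(2, 3), 'meta': {'ids': [torch.tensor(0)], 'name': 'x'}}
    batch = recursive_to(batch, torch.device('cpu'))
    print(batch['meta']['name'])  # 'x' - non-tensor leaves are untouched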
fourm/utils/hmr2_utils/hmr2/utils/geometry.py ADDED
@@ -0,0 +1,109 @@
1
+ # --------------------------------------------------------
2
+ # Based on the 4DHumans, ProHMR, and SPIN code bases
3
+ # https://github.com/shubham-goel/4D-Humans
4
+ # https://github.com/nkolot/ProHMR
5
+ # https://github.com/nkolot/SPIN
6
+ # --------------------------------------------------------
7
+
8
+ from typing import Optional
9
+ import torch
10
+ from torch.nn import functional as F
11
+
12
+ def aa_to_rotmat(theta: torch.Tensor):
13
+ """
14
+ Convert axis-angle representation to rotation matrix.
15
+ Works by first converting it to a quaternion.
16
+ Args:
17
+ theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
18
+ Returns:
19
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
20
+ """
21
+ norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
22
+ angle = torch.unsqueeze(norm, -1)
23
+ normalized = torch.div(theta, angle)
24
+ angle = angle * 0.5
25
+ v_cos = torch.cos(angle)
26
+ v_sin = torch.sin(angle)
27
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
28
+ return quat_to_rotmat(quat)
29
+
30
+ def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
31
+ """
32
+ Convert quaternion representation to rotation matrix.
33
+ Args:
34
+ quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
35
+ Returns:
36
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
37
+ """
38
+ norm_quat = quat
39
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
40
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
41
+
42
+ B = quat.size(0)
43
+
44
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
45
+ wx, wy, wz = w*x, w*y, w*z
46
+ xy, xz, yz = x*y, x*z, y*z
47
+
48
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
49
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
50
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
51
+ return rotMat
52
+
53
+
54
+ def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
55
+ """
56
+ Convert 6D rotation representation to 3x3 rotation matrix.
57
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
58
+ Args:
59
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
60
+ Returns:
61
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
62
+ """
63
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
64
+ a1 = x[:, :, 0]
65
+ a2 = x[:, :, 1]
66
+ b1 = F.normalize(a1)
67
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
68
+ b3 = torch.cross(b1, b2)
69
+ return torch.stack((b1, b2, b3), dim=-1)
70
+
71
+ def perspective_projection(points: torch.Tensor,
72
+ translation: torch.Tensor,
73
+ focal_length: torch.Tensor,
74
+ camera_center: Optional[torch.Tensor] = None,
75
+ rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
76
+ """
77
+ Computes the perspective projection of a set of 3D points.
78
+ Args:
79
+ points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
80
+ translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
81
+ focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
82
+ camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
83
+ rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
84
+ Returns:
85
+ torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
86
+ """
87
+ batch_size = points.shape[0]
88
+ if rotation is None:
89
+ rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
90
+ if camera_center is None:
91
+ camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
92
+ # Populate intrinsic camera matrix K.
93
+ K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
94
+ K[:,0,0] = focal_length[:,0]
95
+ K[:,1,1] = focal_length[:,1]
96
+ K[:,2,2] = 1.
97
+ K[:,:-1, -1] = camera_center
98
+
99
+ # Transform points
100
+ points = torch.einsum('bij,bkj->bki', rotation, points)
101
+ points = points + translation.unsqueeze(1)
102
+
103
+ # Apply perspective distortion
104
+ projected_points = points / points[:,:,-1].unsqueeze(-1)
105
+
106
+ # Apply camera intrinsics
107
+ projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
108
+
109
+ return projected_points[:, :, :-1]
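A minimal sketch of how these helpers compose, with illustrative shapes and values only (the import path follows the file layout of this commit):

import torch
from fourm.utils.hmr2_utils.hmr2.utils.geometry import rot6d_to_rotmat, perspective_projection

B, N = 2, 10                                              # batch size, number of 3D points
# Identity rotation in the 6D representation of Zhou et al.: the first two columns of I_3, flattened.
rot6d = torch.tensor([[1., 0., 0., 0., 1., 0.]]).repeat(B, 1)
R = rot6d_to_rotmat(rot6d)                                # (B, 3, 3), here the identity matrix

points = torch.randn(B, N, 3)                             # 3D points in the camera frame
translation = torch.tensor([[0., 0., 5.]]).repeat(B, 1)   # push the points in front of the camera
focal_length = torch.full((B, 2), 5000.)                  # (fx, fy) in pixels
proj = perspective_projection(points, translation, focal_length, rotation=R)
print(proj.shape)                                         # torch.Size([2, 10, 2])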
fourm/utils/hmr2_utils/hmr2/utils/mesh_renderer.py ADDED
@@ -0,0 +1,155 @@
1
+ # --------------------------------------------------------
2
+ # Based on the 4DHumans and ProHMR code bases
3
+ # https://github.com/shubham-goel/4D-Humans
4
+ # https://github.com/nkolot/ProHMR
5
+ # --------------------------------------------------------
6
+
7
+ import os
8
+ if 'PYOPENGL_PLATFORM' not in os.environ:
9
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
10
+ import torch
11
+ from torchvision.utils import make_grid
12
+ import numpy as np
13
+ import pyrender
14
+ import trimesh
15
+ import cv2
16
+ import torch.nn.functional as F
17
+
18
+ from .render_openpose import render_openpose
19
+
20
+ def create_raymond_lights():
21
+ import pyrender
22
+ thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
23
+ phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])
24
+
25
+ nodes = []
26
+
27
+ for phi, theta in zip(phis, thetas):
28
+ xp = np.sin(theta) * np.cos(phi)
29
+ yp = np.sin(theta) * np.sin(phi)
30
+ zp = np.cos(theta)
31
+
32
+ z = np.array([xp, yp, zp])
33
+ z = z / np.linalg.norm(z)
34
+ x = np.array([-z[1], z[0], 0.0])
35
+ if np.linalg.norm(x) == 0:
36
+ x = np.array([1.0, 0.0, 0.0])
37
+ x = x / np.linalg.norm(x)
38
+ y = np.cross(z, x)
39
+
40
+ matrix = np.eye(4)
41
+ matrix[:3,:3] = np.c_[x,y,z]
42
+ nodes.append(pyrender.Node(
43
+ light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0),
44
+ matrix=matrix
45
+ ))
46
+
47
+ return nodes
48
+
49
+ class MeshRenderer:
50
+
51
+ def __init__(self, cfg, faces=None):
52
+ self.cfg = cfg
53
+ self.focal_length = cfg.EXTRA.FOCAL_LENGTH
54
+ self.img_res = cfg.MODEL.IMAGE_SIZE
55
+ self.renderer = pyrender.OffscreenRenderer(viewport_width=self.img_res,
56
+ viewport_height=self.img_res,
57
+ point_size=1.0)
58
+
59
+ self.camera_center = [self.img_res // 2, self.img_res // 2]
60
+ self.faces = faces
61
+
62
+ def visualize(self, vertices, camera_translation, images, focal_length=None, nrow=3, padding=2):
63
+ images_np = np.transpose(images, (0,2,3,1))
64
+ rend_imgs = []
65
+ for i in range(vertices.shape[0]):
66
+ fl = self.focal_length
67
+ rend_img = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=False), (2,0,1))).float()
68
+ rend_img_side = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=True), (2,0,1))).float()
69
+ rend_imgs.append(torch.from_numpy(images[i]))
70
+ rend_imgs.append(rend_img)
71
+ rend_imgs.append(rend_img_side)
72
+ rend_imgs = make_grid(rend_imgs, nrow=nrow, padding=padding)
73
+ return rend_imgs
74
+
75
+ def visualize_tensorboard(self, vertices, camera_translation, images, pred_keypoints, gt_keypoints, focal_length=None, nrow=5, padding=2):
76
+ images_np = np.transpose(images, (0,2,3,1))
77
+ rend_imgs = []
78
+ pred_keypoints = np.concatenate((pred_keypoints, np.ones_like(pred_keypoints)[:, :, [0]]), axis=-1)
79
+ pred_keypoints = self.img_res * (pred_keypoints + 0.5)
80
+ gt_keypoints[:, :, :-1] = self.img_res * (gt_keypoints[:, :, :-1] + 0.5)
81
+ keypoint_matches = [(1, 12), (2, 8), (3, 7), (4, 6), (5, 9), (6, 10), (7, 11), (8, 14), (9, 2), (10, 1), (11, 0), (12, 3), (13, 4), (14, 5)]
82
+ for i in range(vertices.shape[0]):
83
+ fl = self.focal_length
84
+ rend_img = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=False), (2,0,1))).float()
85
+ rend_img_side = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=True), (2,0,1))).float()
86
+ body_keypoints = pred_keypoints[i, :25]
87
+ extra_keypoints = pred_keypoints[i, -19:]
88
+ for pair in keypoint_matches:
89
+ body_keypoints[pair[0], :] = extra_keypoints[pair[1], :]
90
+ pred_keypoints_img = render_openpose(255 * images_np[i].copy(), body_keypoints) / 255
91
+ body_keypoints = gt_keypoints[i, :25]
92
+ extra_keypoints = gt_keypoints[i, -19:]
93
+ for pair in keypoint_matches:
94
+ if extra_keypoints[pair[1], -1] > 0 and body_keypoints[pair[0], -1] == 0:
95
+ body_keypoints[pair[0], :] = extra_keypoints[pair[1], :]
96
+ gt_keypoints_img = render_openpose(255*images_np[i].copy(), body_keypoints) / 255
97
+ rend_imgs.append(torch.from_numpy(images[i]))
98
+ rend_imgs.append(rend_img)
99
+ rend_imgs.append(rend_img_side)
100
+ rend_imgs.append(torch.from_numpy(pred_keypoints_img).permute(2,0,1))
101
+ rend_imgs.append(torch.from_numpy(gt_keypoints_img).permute(2,0,1))
102
+ rend_imgs = make_grid(rend_imgs, nrow=nrow, padding=padding)
103
+ return rend_imgs
104
+
105
+ def __call__(self, vertices, camera_translation, image, focal_length=5000, text=None, resize=None, side_view=False, baseColorFactor=(1.0, 1.0, 0.9, 1.0), rot_angle=90):
106
+ renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1],
107
+ viewport_height=image.shape[0],
108
+ point_size=1.0)
109
+ material = pyrender.MetallicRoughnessMaterial(
110
+ metallicFactor=0.0,
111
+ alphaMode='OPAQUE',
112
+ baseColorFactor=baseColorFactor)
113
+
114
+ camera_translation[0] *= -1.
115
+
116
+ mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
117
+ if side_view:
118
+ rot = trimesh.transformations.rotation_matrix(
119
+ np.radians(rot_angle), [0, 1, 0])
120
+ mesh.apply_transform(rot)
121
+ rot = trimesh.transformations.rotation_matrix(
122
+ np.radians(180), [1, 0, 0])
123
+ mesh.apply_transform(rot)
124
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
125
+
126
+ scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
127
+ ambient_light=(0.3, 0.3, 0.3))
128
+ scene.add(mesh, 'mesh')
129
+
130
+ camera_pose = np.eye(4)
131
+ camera_pose[:3, 3] = camera_translation
132
+ camera_center = [image.shape[1] / 2., image.shape[0] / 2.]
133
+ camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
134
+ cx=camera_center[0], cy=camera_center[1])
135
+ scene.add(camera, pose=camera_pose)
136
+
137
+
138
+ light_nodes = create_raymond_lights()
139
+ for node in light_nodes:
140
+ scene.add_node(node)
141
+
142
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
143
+ color = color.astype(np.float32) / 255.0
144
+ valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
145
+ if not side_view:
146
+ output_img = (color[:, :, :3] * valid_mask +
147
+ (1 - valid_mask) * image)
148
+ else:
149
+ output_img = color[:, :, :3]
150
+ if resize is not None:
151
+ output_img = cv2.resize(output_img, resize)
152
+
153
+ output_img = output_img.astype(np.float32)
154
+ renderer.delete()
155
+ return output_img
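An illustrative sketch of driving MeshRenderer on a single placeholder triangle, assuming an EGL-capable environment; a plain namespace stands in for the usual yacs config (only EXTRA.FOCAL_LENGTH and MODEL.IMAGE_SIZE are read above), and in practice the vertices and faces come from the SMPL model output:

from types import SimpleNamespace
import numpy as np
from fourm.utils.hmr2_utils.hmr2.utils.mesh_renderer import MeshRenderer

cfg = SimpleNamespace(EXTRA=SimpleNamespace(FOCAL_LENGTH=5000),
                      MODEL=SimpleNamespace(IMAGE_SIZE=256))
faces = np.array([[0, 1, 2]], dtype=np.int64)             # placeholder triangle; use smpl.faces in practice
renderer = MeshRenderer(cfg, faces=faces)

vertices = np.array([[0.0, 0.0, 0.0],                     # placeholder mesh; use predicted SMPL vertices
                     [0.1, 0.0, 0.0],
                     [0.0, 0.1, 0.0]], dtype=np.float32)
camera_translation = np.array([0.0, 0.0, 2.5], dtype=np.float32)
image = np.zeros((256, 256, 3), dtype=np.float32)         # background image with values in [0, 1]
overlay = renderer(vertices, camera_translation, image)   # (256, 256, 3) float32, mesh composited on image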
fourm/utils/hmr2_utils/hmr2/utils/render_openpose.py ADDED
@@ -0,0 +1,155 @@
1
+ # --------------------------------------------------------
2
+ # Based on the 4DHumans and ProHMR code bases
3
+ # https://github.com/shubham-goel/4D-Humans
4
+ # https://github.com/nkolot/ProHMR
5
+ # --------------------------------------------------------
6
+
7
+ """
8
+ Render OpenPose keypoints.
9
+ Code was ported to Python from the official C++ implementation https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/utilities/keypoint.cpp
10
+ """
11
+ import cv2
12
+ import math
13
+ import numpy as np
14
+ from typing import List, Tuple
15
+
16
+ def get_keypoints_rectangle(keypoints: np.array, threshold: float) -> Tuple[float, float, float]:
17
+ """
18
+ Compute rectangle enclosing keypoints above the threshold.
19
+ Args:
20
+ keypoints (np.array): Keypoint array of shape (N, 3).
21
+ threshold (float): Confidence visualization threshold.
22
+ Returns:
23
+ Tuple[float, float, float]: Rectangle width, height and area.
24
+ """
25
+ valid_ind = keypoints[:, -1] > threshold
26
+ if valid_ind.sum() > 0:
27
+ valid_keypoints = keypoints[valid_ind][:, :-1]
28
+ max_x = valid_keypoints[:,0].max()
29
+ max_y = valid_keypoints[:,1].max()
30
+ min_x = valid_keypoints[:,0].min()
31
+ min_y = valid_keypoints[:,1].min()
32
+ width = max_x - min_x
33
+ height = max_y - min_y
34
+ area = width * height
35
+ return width, height, area
36
+ else:
37
+ return 0,0,0
38
+
39
+ def render_keypoints(img: np.array,
40
+ keypoints: np.array,
41
+ pairs: List,
42
+ colors: List,
43
+ thickness_circle_ratio: float,
44
+ thickness_line_ratio_wrt_circle: float,
45
+ pose_scales: List,
46
+ threshold: float = 0.1) -> np.array:
47
+ """
48
+ Render keypoints on input image.
49
+ Args:
50
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
51
+ keypoints (np.array): Keypoint array of shape (N, 3).
52
+ pairs (List): List of keypoint pairs per limb.
53
+ colors (List): List of colors per keypoint.
54
+ thickness_circle_ratio (float): Circle thickness ratio.
55
+ thickness_line_ratio_wrt_circle (float): Line thickness ratio wrt the circle.
56
+ pose_scales (List): List of pose scales.
57
+ threshold (float): Only visualize keypoints with confidence above the threshold.
58
+ Returns:
59
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
60
+ """
61
+ img_orig = img.copy()
62
+ width, height = img.shape[1], img.shape[0]
63
+ area = width * height
64
+
65
+ lineType = 8
66
+ shift = 0
67
+ numberColors = len(colors)
68
+ thresholdRectangle = 0.1
69
+
70
+ person_width, person_height, person_area = get_keypoints_rectangle(keypoints, thresholdRectangle)
71
+ if person_area > 0:
72
+ ratioAreas = min(1, max(person_width / width, person_height / height))
73
+ thicknessRatio = np.maximum(np.round(math.sqrt(area) * thickness_circle_ratio * ratioAreas), 2)
74
+ thicknessCircle = np.maximum(1, thicknessRatio if ratioAreas > 0.05 else -np.ones_like(thicknessRatio))
75
+ thicknessLine = np.maximum(1, np.round(thicknessRatio * thickness_line_ratio_wrt_circle))
76
+ radius = thicknessRatio / 2
77
+
78
+ img = np.ascontiguousarray(img.copy())
79
+ for i, pair in enumerate(pairs):
80
+ index1, index2 = pair
81
+ if keypoints[index1, -1] > threshold and keypoints[index2, -1] > threshold:
82
+ thicknessLineScaled = int(round(min(thicknessLine[index1], thicknessLine[index2]) * pose_scales[0]))
83
+ colorIndex = index2
84
+ color = colors[colorIndex % numberColors]
85
+ keypoint1 = keypoints[index1, :-1].astype(np.int32)
86
+ keypoint2 = keypoints[index2, :-1].astype(np.int32)
87
+ cv2.line(img, tuple(keypoint1.tolist()), tuple(keypoint2.tolist()), tuple(color.tolist()), thicknessLineScaled, lineType, shift)
88
+ for part in range(len(keypoints)):
89
+ faceIndex = part
90
+ if keypoints[faceIndex, -1] > threshold:
91
+ radiusScaled = int(round(radius[faceIndex] * pose_scales[0]))
92
+ thicknessCircleScaled = int(round(thicknessCircle[faceIndex] * pose_scales[0]))
93
+ colorIndex = part
94
+ color = colors[colorIndex % numberColors]
95
+ center = keypoints[faceIndex, :-1].astype(np.int32)
96
+ cv2.circle(img, tuple(center.tolist()), radiusScaled, tuple(color.tolist()), thicknessCircleScaled, lineType, shift)
97
+ return img
98
+
99
+ def render_body_keypoints(img: np.array,
100
+ body_keypoints: np.array) -> np.array:
101
+ """
102
+ Render OpenPose body keypoints on input image.
103
+ Args:
104
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
105
+ body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
106
+ Returns:
107
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
108
+ """
109
+
110
+ thickness_circle_ratio = 1./75. * np.ones(body_keypoints.shape[0])
111
+ thickness_line_ratio_wrt_circle = 0.75
112
+ pairs = []
113
+ pairs = [1,8,1,2,1,5,2,3,3,4,5,6,6,7,8,9,9,10,10,11,8,12,12,13,13,14,1,0,0,15,15,17,0,16,16,18,14,19,19,20,14,21,11,22,22,23,11,24]
114
+ pairs = np.array(pairs).reshape(-1,2)
115
+ colors = [255., 0., 85.,
116
+ 255., 0., 0.,
117
+ 255., 85., 0.,
118
+ 255., 170., 0.,
119
+ 255., 255., 0.,
120
+ 170., 255., 0.,
121
+ 85., 255., 0.,
122
+ 0., 255., 0.,
123
+ 255., 0., 0.,
124
+ 0., 255., 85.,
125
+ 0., 255., 170.,
126
+ 0., 255., 255.,
127
+ 0., 170., 255.,
128
+ 0., 85., 255.,
129
+ 0., 0., 255.,
130
+ 255., 0., 170.,
131
+ 170., 0., 255.,
132
+ 255., 0., 255.,
133
+ 85., 0., 255.,
134
+ 0., 0., 255.,
135
+ 0., 0., 255.,
136
+ 0., 0., 255.,
137
+ 0., 255., 255.,
138
+ 0., 255., 255.,
139
+ 0., 255., 255.]
140
+ colors = np.array(colors).reshape(-1,3)
141
+ pose_scales = [1]
142
+ return render_keypoints(img, body_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1)
143
+
144
+ def render_openpose(img: np.array,
145
+ body_keypoints: np.array) -> np.array:
146
+ """
147
+ Render keypoints in the OpenPose format on input image.
148
+ Args:
149
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
150
+ body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
151
+ Returns:
152
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
153
+ """
154
+ img = render_body_keypoints(img, body_keypoints)
155
+ return img
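A small, self-contained example of the expected inputs: a blank image and 25 synthetic body keypoints whose confidences exceed the 0.1 rendering threshold (values are illustrative only):

import numpy as np
from fourm.utils.hmr2_utils.hmr2.utils.render_openpose import render_openpose

img = np.zeros((256, 256, 3), dtype=np.float32)           # (H, W, 3), pixel values in [0, 255]
body_keypoints = np.zeros((25, 3), dtype=np.float32)      # (N, 3) <====> (x, y, confidence)
body_keypoints[:, 0] = np.linspace(40, 216, 25)           # x coordinates in pixels
body_keypoints[:, 1] = np.linspace(40, 216, 25)           # y coordinates in pixels
body_keypoints[:, 2] = 1.0                                # confidence above the rendering threshold
vis = render_openpose(img, body_keypoints)                # (256, 256, 3) with limbs and joints drawn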