Staty committed
Commit b818573 · verified · 1 Parent(s): 3ea47a7

Upload 22 files

data.py ADDED
@@ -0,0 +1,350 @@
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import math
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torchaudio
13
+ import logging
14
+
15
+ from models.multimodal_preprocessors import SimpleTokenizer
16
+ from PIL import Image
17
+ from pytorchvideo import transforms as pv_transforms
18
+ from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
19
+ from pytorchvideo.data.encoded_video import EncodedVideo
20
+
21
+ from torchvision import transforms
22
+ from torchvision.transforms._transforms_video import NormalizeVideo
23
+
24
+ DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
25
+
26
+ BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz"
27
+
28
+
29
+ def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
30
+ # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
31
+ waveform -= waveform.mean()
32
+ fbank = torchaudio.compliance.kaldi.fbank(
33
+ waveform,
34
+ htk_compat=True,
35
+ sample_frequency=sample_rate,
36
+ use_energy=False,
37
+ window_type="hanning",
38
+ num_mel_bins=num_mel_bins,
39
+ dither=0.0,
40
+ frame_length=25,
41
+ frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
42
+ )
43
+ # Convert to [mel_bins, num_frames] shape
44
+ fbank = fbank.transpose(0, 1)
45
+ # Pad to target_length
46
+ n_frames = fbank.size(1)
47
+ p = target_length - n_frames
48
+ # if p is too large (say >20%), log a warning
49
+ if abs(p) / n_frames > 0.2:
50
+ logging.warning(
51
+ "Large gap between audio n_frames(%d) and "
52
+ "target_length (%d). Is the audio_target_length "
53
+ "setting correct?",
54
+ n_frames,
55
+ target_length,
56
+ )
57
+ # cut and pad
58
+ if p > 0:
59
+ fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
60
+ elif p < 0:
61
+ fbank = fbank[:, 0:target_length]
62
+ # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
63
+ # channel image
64
+ fbank = fbank.unsqueeze(0)
65
+ return fbank
66
+
67
+
68
+ def get_clip_timepoints(clip_sampler, duration):
69
+ # Read out all clips in this video
70
+ all_clips_timepoints = []
71
+ is_last_clip = False
72
+ end = 0.0
73
+ while not is_last_clip:
74
+ start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
75
+ all_clips_timepoints.append((start, end))
76
+ return all_clips_timepoints
77
+
78
+
79
+ def load_and_transform_vision_data(image_paths, device):
80
+ if image_paths is None:
81
+ return None
82
+
83
+ image_outputs = []
84
+ for image_path in image_paths:
85
+ data_transform = transforms.Compose(
86
+ [
87
+ transforms.Resize(
88
+ 224, interpolation=transforms.InterpolationMode.BICUBIC
89
+ ),
90
+ transforms.CenterCrop(224),
91
+ transforms.ToTensor(),
92
+ transforms.Normalize(
93
+ mean=(0.48145466, 0.4578275, 0.40821073),
94
+ std=(0.26862954, 0.26130258, 0.27577711),
95
+ ),
96
+ ]
97
+ )
98
+ with open(image_path, "rb") as fopen:
99
+ image = Image.open(fopen).convert("RGB")
100
+
101
+ image = data_transform(image).to(device)
102
+ image_outputs.append(image)
103
+ return torch.stack(image_outputs, dim=0)
104
+
105
+
106
+ def load_and_transform_text(text, device):
107
+ if text is None:
108
+ return None
109
+ tokenizer = SimpleTokenizer(bpe_path=BPE_PATH)
110
+ tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text]
111
+ tokens = torch.cat(tokens, dim=0)
112
+ return tokens
113
+
114
+
115
+ def load_and_transform_audio_data(
116
+ audio_paths,
117
+ device,
118
+ num_mel_bins=128,
119
+ target_length=204,
120
+ sample_rate=16000,
121
+ clip_duration=2,
122
+ clips_per_video=3,
123
+ mean=-4.268,
124
+ std=9.138,
125
+ ):
126
+ if audio_paths is None:
127
+ return None
128
+
129
+ audio_outputs = []
130
+ clip_sampler = ConstantClipsPerVideoSampler(
131
+ clip_duration=clip_duration, clips_per_video=clips_per_video
132
+ )
133
+
134
+ for audio_path in audio_paths:
135
+ waveform, sr = torchaudio.load(audio_path)
136
+ if sample_rate != sr:
137
+ waveform = torchaudio.functional.resample(
138
+ waveform, orig_freq=sr, new_freq=sample_rate
139
+ )
140
+ all_clips_timepoints = get_clip_timepoints(
141
+ clip_sampler, waveform.size(1) / sample_rate
142
+ )
143
+ all_clips = []
144
+ for clip_timepoints in all_clips_timepoints:
145
+ waveform_clip = waveform[
146
+ :,
147
+ int(clip_timepoints[0] * sample_rate) : int(
148
+ clip_timepoints[1] * sample_rate
149
+ ),
150
+ ]
151
+ waveform_melspec = waveform2melspec(
152
+ waveform_clip, sample_rate, num_mel_bins, target_length
153
+ )
154
+ all_clips.append(waveform_melspec)
155
+
156
+ normalize = transforms.Normalize(mean=mean, std=std)
157
+ all_clips = [normalize(ac).to(device) for ac in all_clips]
158
+
159
+ all_clips = torch.stack(all_clips, dim=0)
160
+ audio_outputs.append(all_clips)
161
+
162
+ return torch.stack(audio_outputs, dim=0)
163
+
164
+
165
+ def get_clip_timepoints(clip_sampler, duration):
166
+ # Read out all clips in this video
167
+ all_clips_timepoints = []
168
+ is_last_clip = False
169
+ end = 0.0
170
+ while not is_last_clip:
171
+ start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
172
+ all_clips_timepoints.append((start, end))
173
+ return all_clips_timepoints
174
+
175
+
176
+ def crop_boxes(boxes, x_offset, y_offset):
177
+ """
178
+ Perform crop on the bounding boxes given the offsets.
179
+ Args:
180
+ boxes (ndarray or None): bounding boxes to perform crop. The dimension
181
+ is `num boxes` x 4.
182
+ x_offset (int): cropping offset in the x axis.
183
+ y_offset (int): cropping offset in the y axis.
184
+ Returns:
185
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
186
+ `num boxes` x 4.
187
+ """
188
+ cropped_boxes = boxes.copy()
189
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
190
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
191
+
192
+ return cropped_boxes
193
+
194
+
195
+ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
196
+ """
197
+ Perform uniform spatial sampling on the images and corresponding boxes.
198
+ Args:
199
+ images (tensor): images to perform uniform crop. The dimension is
200
+ `num frames` x `channel` x `height` x `width`.
201
+ size (int): size of height and width to crop the images.
202
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
203
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
204
+ crop if height is larger than width.
205
+ boxes (ndarray or None): optional. Corresponding boxes to images.
206
+ Dimension is `num boxes` x 4.
207
+ scale_size (int): optional. If not None, resize the images to scale_size before
208
+ performing any crop.
209
+ Returns:
210
+ cropped (tensor): images with dimension of
211
+ `num frames` x `channel` x `size` x `size`.
212
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
213
+ `num boxes` x 4.
214
+ """
215
+ assert spatial_idx in [0, 1, 2]
216
+ ndim = len(images.shape)
217
+ if ndim == 3:
218
+ images = images.unsqueeze(0)
219
+ height = images.shape[2]
220
+ width = images.shape[3]
221
+
222
+ if scale_size is not None:
223
+ if width <= height:
224
+ width, height = scale_size, int(height / width * scale_size)
225
+ else:
226
+ width, height = int(width / height * scale_size), scale_size
227
+ images = torch.nn.functional.interpolate(
228
+ images,
229
+ size=(height, width),
230
+ mode="bilinear",
231
+ align_corners=False,
232
+ )
233
+
234
+ y_offset = int(math.ceil((height - size) / 2))
235
+ x_offset = int(math.ceil((width - size) / 2))
236
+
237
+ if height > width:
238
+ if spatial_idx == 0:
239
+ y_offset = 0
240
+ elif spatial_idx == 2:
241
+ y_offset = height - size
242
+ else:
243
+ if spatial_idx == 0:
244
+ x_offset = 0
245
+ elif spatial_idx == 2:
246
+ x_offset = width - size
247
+ cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
248
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
249
+ if ndim == 3:
250
+ cropped = cropped.squeeze(0)
251
+ return cropped, cropped_boxes
252
+
253
+
254
+ class SpatialCrop(nn.Module):
255
+ """
256
+ Convert the video into 3 smaller clips spatially. Must be used after the
257
+ temporal crops to get spatial crops, and should be used with
258
+ -2 in the spatial crop at the slowfast augmentation stage (so full
259
+ frames are passed in here). Will return a larger list with the
260
+ 3x spatial crops as well.
261
+ """
262
+
263
+ def __init__(self, crop_size: int = 224, num_crops: int = 3):
264
+ super().__init__()
265
+ self.crop_size = crop_size
266
+ if num_crops == 3:
267
+ self.crops_to_ext = [0, 1, 2]
268
+ self.flipped_crops_to_ext = []
269
+ elif num_crops == 1:
270
+ self.crops_to_ext = [1]
271
+ self.flipped_crops_to_ext = []
272
+ else:
273
+ raise NotImplementedError("Nothing else supported yet")
274
+
275
+ def forward(self, videos):
276
+ """
277
+ Args:
278
+ videos: A list of C, T, H, W videos.
279
+ Returns:
280
+ videos: A list with 3x the number of elements. Each video converted
281
+ to C, T, H', W' by spatial cropping.
282
+ """
283
+ assert isinstance(videos, list), "Must be a list of videos after temporal crops"
284
+ assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
285
+ res = []
286
+ for video in videos:
287
+ for spatial_idx in self.crops_to_ext:
288
+ res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
289
+ if not self.flipped_crops_to_ext:
290
+ continue
291
+ flipped_video = transforms.functional.hflip(video)
292
+ for spatial_idx in self.flipped_crops_to_ext:
293
+ res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
294
+ return res
295
+
296
+
297
+ def load_and_transform_video_data(
298
+ video_paths,
299
+ device,
300
+ clip_duration=2,
301
+ clips_per_video=5,
302
+ sample_rate=16000,
303
+ ):
304
+ if video_paths is None:
305
+ return None
306
+
307
+ video_outputs = []
308
+ video_transform = transforms.Compose(
309
+ [
310
+ pv_transforms.ShortSideScale(224),
311
+ NormalizeVideo(
312
+ mean=(0.48145466, 0.4578275, 0.40821073),
313
+ std=(0.26862954, 0.26130258, 0.27577711),
314
+ ),
315
+ ]
316
+ )
317
+
318
+ clip_sampler = ConstantClipsPerVideoSampler(
319
+ clip_duration=clip_duration, clips_per_video=clips_per_video
320
+ )
321
+ frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
322
+
323
+ for video_path in video_paths:
324
+ video = EncodedVideo.from_path(
325
+ video_path,
326
+ decoder="decord",
327
+ decode_audio=False,
328
+ **{"sample_rate": sample_rate},
329
+ )
330
+
331
+ all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
332
+
333
+ all_video = []
334
+ for clip_timepoints in all_clips_timepoints:
335
+ # Read the clip, get frames
336
+ clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
337
+ if clip is None:
338
+ raise ValueError("No clip found")
339
+ video_clip = frame_sampler(clip["video"])
340
+ video_clip = video_clip / 255.0 # since this is float, need 0-1
341
+
342
+ all_video.append(video_clip)
343
+
344
+ all_video = [video_transform(clip) for clip in all_video]
345
+ all_video = SpatialCrop(224, num_crops=3)(all_video)
346
+
347
+ all_video = torch.stack(all_video, dim=0)
348
+ video_outputs.append(all_video)
349
+
350
+ return torch.stack(video_outputs, dim=0).to(device)
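
A minimal usage sketch for the loaders above, assuming this upload's root is the working directory (so the bpe/ vocabulary and the models package resolve); the file paths and prompts are placeholders.

import torch

import data
from models.imagebind_model import ModalityType

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Placeholder inputs; each loader returns a batched tensor on `device`
# (audio and video additionally carry a leading per-sample clip dimension).
inputs = {
    ModalityType.TEXT: data.load_and_transform_text(["a dog.", "a car."], device),
    ModalityType.VISION: data.load_and_transform_vision_data(["assets/dog_image.jpg"], device),
    ModalityType.AUDIO: data.load_and_transform_audio_data(["assets/dog_audio.wav"], device),
}
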
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (159 Bytes)
models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (152 Bytes)
models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (145 Bytes)
models/__pycache__/helpers.cpython-310.pyc ADDED
Binary file (5.14 kB)
models/__pycache__/helpers.cpython-312.pyc ADDED
Binary file (7.84 kB)
models/__pycache__/helpers.cpython-39.pyc ADDED
Binary file (5.18 kB)
models/__pycache__/imagebind_model.cpython-310.pyc ADDED
Binary file (8.28 kB)
models/__pycache__/imagebind_model.cpython-312.pyc ADDED
Binary file (13.4 kB)
models/__pycache__/imagebind_model.cpython-39.pyc ADDED
Binary file (8.12 kB)
models/__pycache__/multimodal_preprocessors.cpython-310.pyc ADDED
Binary file (19.9 kB)
models/__pycache__/multimodal_preprocessors.cpython-312.pyc ADDED
Binary file (33.3 kB)
models/__pycache__/multimodal_preprocessors.cpython-39.pyc ADDED
Binary file (19.9 kB)
models/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (8 kB)
models/__pycache__/transformer.cpython-312.pyc ADDED
Binary file (12.6 kB)
models/__pycache__/transformer.cpython-39.pyc ADDED
Binary file (7.86 kB)
models/helpers.py ADDED
@@ -0,0 +1,141 @@
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import math
9
+
10
+ import einops
11
+ import numpy as np
12
+ import torch
13
+
14
+ import torch.nn as nn
15
+
16
+
17
+ class Normalize(nn.Module):
18
+ def __init__(self, dim: int) -> None:
19
+ super().__init__()
20
+ self.dim = dim
21
+
22
+ def forward(self, x):
23
+ return torch.nn.functional.normalize(x, dim=self.dim, p=2)
24
+
25
+
26
+ class LearnableLogitScaling(nn.Module):
27
+ def __init__(
28
+ self,
29
+ logit_scale_init: float = 1 / 0.07,
30
+ learnable: bool = True,
31
+ max_logit_scale: float = 100,
32
+ ) -> None:
33
+ super().__init__()
34
+ self.max_logit_scale = max_logit_scale
35
+ self.logit_scale_init = logit_scale_init
36
+ self.learnable = learnable
37
+ log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
38
+ if learnable:
39
+ self.log_logit_scale = nn.Parameter(log_logit_scale)
40
+ else:
41
+ self.register_buffer("log_logit_scale", log_logit_scale)
42
+
43
+ def forward(self, x):
44
+ return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
45
+
46
+ def extra_repr(self):
47
+ st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
48
+ return st
49
+
50
+
51
+ class EinOpsRearrange(nn.Module):
52
+ def __init__(self, rearrange_expr: str, **kwargs) -> None:
53
+ super().__init__()
54
+ self.rearrange_expr = rearrange_expr
55
+ self.kwargs = kwargs
56
+
57
+ def forward(self, x):
58
+ assert isinstance(x, torch.Tensor)
59
+ return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
60
+
61
+
62
+ class VerboseNNModule(nn.Module):
63
+ """
64
+ Wrapper around nn.Module that prints registered buffers and parameter names.
65
+ """
66
+
67
+ @staticmethod
68
+ def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
69
+ st = (
70
+ "("
71
+ + name
72
+ + "): "
73
+ + "tensor("
74
+ + str(tuple(tensor[1].shape))
75
+ + ", requires_grad="
76
+ + str(tensor[1].requires_grad)
77
+ + ")\n"
78
+ )
79
+ return st
80
+
81
+ def extra_repr(self) -> str:
82
+ named_modules = set()
83
+ for p in self.named_modules():
84
+ named_modules.update([p[0]])
85
+ named_modules = list(named_modules)
86
+
87
+ string_repr = ""
88
+ for p in self.named_parameters():
89
+ name = p[0].split(".")[0]
90
+ if name not in named_modules:
91
+ string_repr += self.get_readable_tensor_repr(name, p)
92
+
93
+ for p in self.named_buffers():
94
+ name = p[0].split(".")[0]
95
+ string_repr += self.get_readable_tensor_repr(name, p)
96
+
97
+ return string_repr
98
+
99
+
100
+ def cast_if_src_dtype(
101
+ tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
102
+ ):
103
+ updated = False
104
+ if tensor.dtype == src_dtype:
105
+ tensor = tensor.to(dtype=tgt_dtype)
106
+ updated = True
107
+ return tensor, updated
108
+
109
+
110
+ class QuickGELU(nn.Module):
111
+ # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
112
+ def forward(self, x: torch.Tensor):
113
+ return x * torch.sigmoid(1.702 * x)
114
+
115
+
116
+ class SelectElement(nn.Module):
117
+ def __init__(self, index) -> None:
118
+ super().__init__()
119
+ self.index = index
120
+
121
+ def forward(self, x):
122
+ assert x.ndim >= 3
123
+ return x[:, self.index, ...]
124
+
125
+
126
+ class SelectEOSAndProject(nn.Module):
127
+ """
128
+ Text Pooling used in OpenCLIP
129
+ """
130
+
131
+ def __init__(self, proj: nn.Module) -> None:
132
+ super().__init__()
133
+ self.proj = proj
134
+
135
+ def forward(self, x, seq_len):
136
+ assert x.ndim == 3
137
+ # x is of shape B x L x D
138
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
139
+ x = x[torch.arange(x.shape[0]), seq_len]
140
+ x = self.proj(x)
141
+ return x
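
A small sketch of how Normalize and LearnableLogitScaling compose; this is the same pairing the postprocessors in models/imagebind_model.py use, and the embedding shape below is just an arbitrary example.

import torch
from models.helpers import LearnableLogitScaling, Normalize

# L2-normalize along the last dim, then multiply by a fixed logit scale of 20.
post = torch.nn.Sequential(
    Normalize(dim=-1),
    LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
)
emb = torch.randn(4, 1024)      # example batch of 4 embeddings
scaled = post(emb)
print(scaled.norm(dim=-1))      # ~20.0 for every row
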
models/imagebind_model.py ADDED
@@ -0,0 +1,517 @@
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import os
10
+ import urllib
11
+ from functools import partial
12
+ from types import SimpleNamespace
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from models.helpers import (
18
+ EinOpsRearrange,
19
+ LearnableLogitScaling,
20
+ Normalize,
21
+ SelectElement,
22
+ SelectEOSAndProject,
23
+ )
24
+ from models.multimodal_preprocessors import (
25
+ AudioPreprocessor,
26
+ IMUPreprocessor,
27
+ PadIm2Video,
28
+ PatchEmbedGeneric,
29
+ RGBDTPreprocessor,
30
+ SpatioTemporalPosEmbeddingHelper,
31
+ TextPreprocessor,
32
+ ThermalPreprocessor,
33
+ )
34
+
35
+ from models.transformer import MultiheadAttention, SimpleTransformer
36
+
37
+
38
+ ModalityType = SimpleNamespace(
39
+ VISION="vision",
40
+ TEXT="text",
41
+ AUDIO="audio",
42
+ THERMAL="thermal",
43
+ DEPTH="depth",
44
+ IMU="imu",
45
+ )
46
+
47
+
48
+ class ImageBindModel(nn.Module):
49
+ def __init__(
50
+ self,
51
+ video_frames=2,
52
+ kernel_size=(2, 14, 14),
53
+ audio_kernel_size=16,
54
+ audio_stride=10,
55
+ out_embed_dim=768,
56
+ vision_embed_dim=1024,
57
+ vision_num_blocks=24,
58
+ vision_num_heads=16,
59
+ audio_embed_dim=768,
60
+ audio_num_blocks=12,
61
+ audio_num_heads=12,
62
+ audio_num_mel_bins=128,
63
+ audio_target_len=204,
64
+ audio_drop_path=0.1,
65
+ text_embed_dim=768,
66
+ text_num_blocks=12,
67
+ text_num_heads=12,
68
+ depth_embed_dim=384,
69
+ depth_kernel_size=16,
70
+ depth_num_blocks=12,
71
+ depth_num_heads=8,
72
+ depth_drop_path=0.0,
73
+ thermal_embed_dim=768,
74
+ thermal_kernel_size=16,
75
+ thermal_num_blocks=12,
76
+ thermal_num_heads=12,
77
+ thermal_drop_path=0.0,
78
+ imu_embed_dim=512,
79
+ imu_kernel_size=8,
80
+ imu_num_blocks=6,
81
+ imu_num_heads=8,
82
+ imu_drop_path=0.7,
83
+ ):
84
+ super().__init__()
85
+
86
+ self.modality_preprocessors = self._create_modality_preprocessors(
87
+ video_frames,
88
+ vision_embed_dim,
89
+ kernel_size,
90
+ text_embed_dim,
91
+ audio_embed_dim,
92
+ audio_kernel_size,
93
+ audio_stride,
94
+ audio_num_mel_bins,
95
+ audio_target_len,
96
+ depth_embed_dim,
97
+ depth_kernel_size,
98
+ thermal_embed_dim,
99
+ thermal_kernel_size,
100
+ imu_embed_dim,
101
+ )
102
+
103
+ self.modality_trunks = self._create_modality_trunks(
104
+ vision_embed_dim,
105
+ vision_num_blocks,
106
+ vision_num_heads,
107
+ text_embed_dim,
108
+ text_num_blocks,
109
+ text_num_heads,
110
+ audio_embed_dim,
111
+ audio_num_blocks,
112
+ audio_num_heads,
113
+ audio_drop_path,
114
+ depth_embed_dim,
115
+ depth_num_blocks,
116
+ depth_num_heads,
117
+ depth_drop_path,
118
+ thermal_embed_dim,
119
+ thermal_num_blocks,
120
+ thermal_num_heads,
121
+ thermal_drop_path,
122
+ imu_embed_dim,
123
+ imu_num_blocks,
124
+ imu_num_heads,
125
+ imu_drop_path,
126
+ )
127
+
128
+ self.modality_heads = self._create_modality_heads(
129
+ out_embed_dim,
130
+ vision_embed_dim,
131
+ text_embed_dim,
132
+ audio_embed_dim,
133
+ depth_embed_dim,
134
+ thermal_embed_dim,
135
+ imu_embed_dim,
136
+ )
137
+
138
+ self.modality_postprocessors = self._create_modality_postprocessors(
139
+ out_embed_dim
140
+ )
141
+
142
+ def _create_modality_preprocessors(
143
+ self,
144
+ video_frames=2,
145
+ vision_embed_dim=1024,
146
+ kernel_size=(2, 14, 14),
147
+ text_embed_dim=768,
148
+ audio_embed_dim=768,
149
+ audio_kernel_size=16,
150
+ audio_stride=10,
151
+ audio_num_mel_bins=128,
152
+ audio_target_len=204,
153
+ depth_embed_dim=768,
154
+ depth_kernel_size=16,
155
+ thermal_embed_dim=768,
156
+ thermal_kernel_size=16,
157
+ imu_embed_dim=512,
158
+ ):
159
+ rgbt_stem = PatchEmbedGeneric(
160
+ proj_stem=[
161
+ PadIm2Video(pad_type="repeat", ntimes=2),
162
+ nn.Conv3d(
163
+ in_channels=3,
164
+ kernel_size=kernel_size,
165
+ out_channels=vision_embed_dim,
166
+ stride=kernel_size,
167
+ bias=False,
168
+ ),
169
+ ]
170
+ )
171
+ rgbt_preprocessor = RGBDTPreprocessor(
172
+ img_size=[3, video_frames, 224, 224],
173
+ num_cls_tokens=1,
174
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
175
+ rgbt_stem=rgbt_stem,
176
+ depth_stem=None,
177
+ )
178
+
179
+ text_preprocessor = TextPreprocessor(
180
+ context_length=77,
181
+ vocab_size=49408,
182
+ embed_dim=text_embed_dim,
183
+ causal_masking=True,
184
+ )
185
+
186
+ audio_stem = PatchEmbedGeneric(
187
+ proj_stem=[
188
+ nn.Conv2d(
189
+ in_channels=1,
190
+ kernel_size=audio_kernel_size,
191
+ stride=audio_stride,
192
+ out_channels=audio_embed_dim,
193
+ bias=False,
194
+ ),
195
+ ],
196
+ norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
197
+ )
198
+ audio_preprocessor = AudioPreprocessor(
199
+ img_size=[1, audio_num_mel_bins, audio_target_len],
200
+ num_cls_tokens=1,
201
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
202
+ audio_stem=audio_stem,
203
+ )
204
+
205
+ depth_stem = PatchEmbedGeneric(
206
+ [
207
+ nn.Conv2d(
208
+ kernel_size=depth_kernel_size,
209
+ in_channels=1,
210
+ out_channels=depth_embed_dim,
211
+ stride=depth_kernel_size,
212
+ bias=False,
213
+ ),
214
+ ],
215
+ norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
216
+ )
217
+
218
+ depth_preprocessor = RGBDTPreprocessor(
219
+ img_size=[1, 224, 224],
220
+ num_cls_tokens=1,
221
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
222
+ rgbt_stem=None,
223
+ depth_stem=depth_stem,
224
+ )
225
+
226
+ thermal_stem = PatchEmbedGeneric(
227
+ [
228
+ nn.Conv2d(
229
+ kernel_size=thermal_kernel_size,
230
+ in_channels=1,
231
+ out_channels=thermal_embed_dim,
232
+ stride=thermal_kernel_size,
233
+ bias=False,
234
+ ),
235
+ ],
236
+ norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
237
+ )
238
+ thermal_preprocessor = ThermalPreprocessor(
239
+ img_size=[1, 224, 224],
240
+ num_cls_tokens=1,
241
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
242
+ thermal_stem=thermal_stem,
243
+ )
244
+
245
+ imu_stem = PatchEmbedGeneric(
246
+ [
247
+ nn.Linear(
248
+ in_features=48,
249
+ out_features=imu_embed_dim,
250
+ bias=False,
251
+ ),
252
+ ],
253
+ norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
254
+ )
255
+
256
+ imu_preprocessor = IMUPreprocessor(
257
+ img_size=[6, 2000],
258
+ num_cls_tokens=1,
259
+ kernel_size=8,
260
+ embed_dim=imu_embed_dim,
261
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
262
+ imu_stem=imu_stem,
263
+ )
264
+
265
+ modality_preprocessors = {
266
+ ModalityType.VISION: rgbt_preprocessor,
267
+ ModalityType.TEXT: text_preprocessor,
268
+ ModalityType.AUDIO: audio_preprocessor,
269
+ ModalityType.DEPTH: depth_preprocessor,
270
+ ModalityType.THERMAL: thermal_preprocessor,
271
+ ModalityType.IMU: imu_preprocessor,
272
+ }
273
+
274
+ return nn.ModuleDict(modality_preprocessors)
275
+
276
+ def _create_modality_trunks(
277
+ self,
278
+ vision_embed_dim=1024,
279
+ vision_num_blocks=24,
280
+ vision_num_heads=16,
281
+ text_embed_dim=768,
282
+ text_num_blocks=12,
283
+ text_num_heads=12,
284
+ audio_embed_dim=768,
285
+ audio_num_blocks=12,
286
+ audio_num_heads=12,
287
+ audio_drop_path=0.0,
288
+ depth_embed_dim=768,
289
+ depth_num_blocks=12,
290
+ depth_num_heads=12,
291
+ depth_drop_path=0.0,
292
+ thermal_embed_dim=768,
293
+ thermal_num_blocks=12,
294
+ thermal_num_heads=12,
295
+ thermal_drop_path=0.0,
296
+ imu_embed_dim=512,
297
+ imu_num_blocks=6,
298
+ imu_num_heads=8,
299
+ imu_drop_path=0.7,
300
+ ):
301
+ def instantiate_trunk(
302
+ embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
303
+ ):
304
+ return SimpleTransformer(
305
+ embed_dim=embed_dim,
306
+ num_blocks=num_blocks,
307
+ ffn_dropout_rate=0.0,
308
+ drop_path_rate=drop_path,
309
+ attn_target=partial(
310
+ MultiheadAttention,
311
+ embed_dim=embed_dim,
312
+ num_heads=num_heads,
313
+ bias=True,
314
+ add_bias_kv=add_bias_kv,
315
+ ),
316
+ pre_transformer_layer=nn.Sequential(
317
+ nn.LayerNorm(embed_dim, eps=1e-6)
318
+ if pre_transformer_ln
319
+ else nn.Identity(),
320
+ EinOpsRearrange("b l d -> l b d"),
321
+ ),
322
+ post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
323
+ )
324
+
325
+ modality_trunks = {}
326
+ modality_trunks[ModalityType.VISION] = instantiate_trunk(
327
+ vision_embed_dim,
328
+ vision_num_blocks,
329
+ vision_num_heads,
330
+ pre_transformer_ln=True,
331
+ add_bias_kv=False,
332
+ drop_path=0.0,
333
+ )
334
+ modality_trunks[ModalityType.TEXT] = instantiate_trunk(
335
+ text_embed_dim,
336
+ text_num_blocks,
337
+ text_num_heads,
338
+ pre_transformer_ln=False,
339
+ add_bias_kv=False,
340
+ drop_path=0.0,
341
+ )
342
+ modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
343
+ audio_embed_dim,
344
+ audio_num_blocks,
345
+ audio_num_heads,
346
+ pre_transformer_ln=False,
347
+ add_bias_kv=True,
348
+ drop_path=audio_drop_path,
349
+ )
350
+ modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
351
+ depth_embed_dim,
352
+ depth_num_blocks,
353
+ depth_num_heads,
354
+ pre_transformer_ln=False,
355
+ add_bias_kv=True,
356
+ drop_path=depth_drop_path,
357
+ )
358
+ modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
359
+ thermal_embed_dim,
360
+ thermal_num_blocks,
361
+ thermal_num_heads,
362
+ pre_transformer_ln=False,
363
+ add_bias_kv=True,
364
+ drop_path=thermal_drop_path,
365
+ )
366
+ modality_trunks[ModalityType.IMU] = instantiate_trunk(
367
+ imu_embed_dim,
368
+ imu_num_blocks,
369
+ imu_num_heads,
370
+ pre_transformer_ln=False,
371
+ add_bias_kv=True,
372
+ drop_path=imu_drop_path,
373
+ )
374
+
375
+ return nn.ModuleDict(modality_trunks)
376
+
377
+ def _create_modality_heads(
378
+ self,
379
+ out_embed_dim,
380
+ vision_embed_dim,
381
+ text_embed_dim,
382
+ audio_embed_dim,
383
+ depth_embed_dim,
384
+ thermal_embed_dim,
385
+ imu_embed_dim,
386
+ ):
387
+ modality_heads = {}
388
+
389
+ modality_heads[ModalityType.VISION] = nn.Sequential(
390
+ nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
391
+ SelectElement(index=0),
392
+ nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
393
+ )
394
+
395
+ modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
396
+ proj=nn.Sequential(
397
+ nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
398
+ nn.Linear(text_embed_dim, out_embed_dim, bias=False),
399
+ )
400
+ )
401
+
402
+ modality_heads[ModalityType.AUDIO] = nn.Sequential(
403
+ nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
404
+ SelectElement(index=0),
405
+ nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
406
+ )
407
+
408
+ modality_heads[ModalityType.DEPTH] = nn.Sequential(
409
+ nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
410
+ SelectElement(index=0),
411
+ nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
412
+ )
413
+
414
+ modality_heads[ModalityType.THERMAL] = nn.Sequential(
415
+ nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
416
+ SelectElement(index=0),
417
+ nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
418
+ )
419
+
420
+ modality_heads[ModalityType.IMU] = nn.Sequential(
421
+ nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
422
+ SelectElement(index=0),
423
+ nn.Dropout(p=0.5),
424
+ nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
425
+ )
426
+
427
+ return nn.ModuleDict(modality_heads)
428
+
429
+ def _create_modality_postprocessors(self, out_embed_dim):
430
+ modality_postprocessors = {}
431
+
432
+ modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
433
+ modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
434
+ Normalize(dim=-1), LearnableLogitScaling(learnable=True)
435
+ )
436
+ modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
437
+ Normalize(dim=-1),
438
+ LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
439
+ )
440
+ modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
441
+ Normalize(dim=-1),
442
+ LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
443
+ )
444
+ modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
445
+ Normalize(dim=-1),
446
+ LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
447
+ )
448
+ modality_postprocessors[ModalityType.IMU] = nn.Sequential(
449
+ Normalize(dim=-1),
450
+ LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
451
+ )
452
+
453
+ return nn.ModuleDict(modality_postprocessors)
454
+
455
+ def forward(self, inputs):
456
+ outputs = {}
457
+ for modality_key, modality_value in inputs.items():
458
+ reduce_list = (
459
+ modality_value.ndim >= 5
460
+ ) # Audio and Video inputs consist of multiple clips
461
+ if reduce_list:
462
+ B, S = modality_value.shape[:2]
463
+ modality_value = modality_value.reshape(
464
+ B * S, *modality_value.shape[2:]
465
+ )
466
+
467
+ if modality_value is not None:
468
+ modality_value = self.modality_preprocessors[modality_key](
469
+ **{modality_key: modality_value}
470
+ )
471
+ trunk_inputs = modality_value["trunk"]
472
+ head_inputs = modality_value["head"]
473
+ modality_value = self.modality_trunks[modality_key](**trunk_inputs)
474
+ modality_value = self.modality_heads[modality_key](
475
+ modality_value, **head_inputs
476
+ )
477
+ modality_value = self.modality_postprocessors[modality_key](
478
+ modality_value
479
+ )
480
+
481
+ if reduce_list:
482
+ modality_value = modality_value.reshape(B, S, -1)
483
+ modality_value = modality_value.mean(dim=1)
484
+
485
+ outputs[modality_key] = modality_value
486
+
487
+ return outputs
488
+
489
+
490
+ def imagebind_huge(pretrained=False):
491
+ model = ImageBindModel(
492
+ vision_embed_dim=1280,
493
+ vision_num_blocks=32,
494
+ vision_num_heads=16,
495
+ text_embed_dim=1024,
496
+ text_num_blocks=24,
497
+ text_num_heads=16,
498
+ out_embed_dim=1024,
499
+ audio_drop_path=0.1,
500
+ imu_drop_path=0.7,
501
+ )
502
+
503
+ if pretrained:
504
+ if not os.path.exists(".checkpoints/imagebind_huge.pth"):
505
+ print(
506
+ "Downloading imagebind weights to .checkpoints/imagebind_huge.pth ..."
507
+ )
508
+ os.makedirs(".checkpoints", exist_ok=True)
509
+ torch.hub.download_url_to_file(
510
+ "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
511
+ ".checkpoints/imagebind_huge.pth",
512
+ progress=True,
513
+ )
514
+
515
+ model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth"))
516
+
517
+ return model
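
An end-to-end sketch of how imagebind_huge is typically combined with the loaders in data.py; the image path is a placeholder, and pretrained=True needs network access to download the checkpoint.

import torch

import data
from models.imagebind_model import imagebind_huge, ModalityType

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = imagebind_huge(pretrained=True)
model.eval()
model.to(device)

inputs = {
    ModalityType.TEXT: data.load_and_transform_text(["a dog.", "a car."], device),
    ModalityType.VISION: data.load_and_transform_vision_data(["assets/dog_image.jpg"], device),
}

with torch.no_grad():
    embeddings = model(inputs)   # dict: modality -> (batch, out_embed_dim) tensor

# The postprocessors already normalize (and scale) the embeddings, so a dot
# product gives cross-modal similarity logits.
logits = embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T
print(torch.softmax(logits, dim=-1))
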
models/multimodal_preprocessors.py ADDED
@@ -0,0 +1,687 @@
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import gzip
9
+ import html
10
+ import io
11
+ import math
12
+ from functools import lru_cache
13
+ from typing import Callable, List, Optional
14
+
15
+ import ftfy
16
+
17
+ import numpy as np
18
+ import regex as re
19
+ import torch
20
+ import torch.nn as nn
21
+ from iopath.common.file_io import g_pathmgr
22
+ from timm.models.layers import trunc_normal_
23
+
24
+ from models.helpers import cast_if_src_dtype, VerboseNNModule
25
+
26
+
27
+ def get_sinusoid_encoding_table(n_position, d_hid):
28
+ """Sinusoid position encoding table"""
29
+
30
+ # TODO: make it with torch instead of numpy
31
+ def get_position_angle_vec(position):
32
+ return [
33
+ position / np.power(10000, 2 * (hid_j // 2) / d_hid)
34
+ for hid_j in range(d_hid)
35
+ ]
36
+
37
+ sinusoid_table = np.array(
38
+ [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
39
+ )
40
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
41
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
42
+
43
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
44
+
45
+
46
+ def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
47
+ N = pos_embed.shape[1]
48
+ if N == target_spatial_size:
49
+ return pos_embed
50
+ dim = pos_embed.shape[-1]
51
+ # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
52
+ pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
53
+ pos_embed = nn.functional.interpolate(
54
+ pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
55
+ 0, 3, 1, 2
56
+ ),
57
+ scale_factor=math.sqrt(target_spatial_size / N),
58
+ mode="bicubic",
59
+ )
60
+ if updated:
61
+ pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16)
62
+ pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
63
+ return pos_embed
64
+
65
+
66
+ def interpolate_pos_encoding(
67
+ npatch_per_img,
68
+ pos_embed,
69
+ patches_layout,
70
+ input_shape=None,
71
+ first_patch_idx=1,
72
+ ):
73
+ assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none"
74
+ N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists
75
+ if npatch_per_img == N:
76
+ return pos_embed
77
+
78
+ assert (
79
+ patches_layout[-1] == patches_layout[-2]
80
+ ), "Interpolation of pos embed not supported for non-square layouts"
81
+
82
+ class_emb = pos_embed[:, :first_patch_idx]
83
+ pos_embed = pos_embed[:, first_patch_idx:]
84
+
85
+ if input_shape is None or patches_layout[0] == 1:
86
+ # simple 2D pos embedding, no temporal component
87
+ pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed)
88
+ elif patches_layout[0] > 1:
89
+ # pos embed has a temporal component
90
+ assert len(input_shape) == 4, "temporal interpolation not supported"
91
+ # we only support 2D interpolation in this case
92
+ num_frames = patches_layout[0]
93
+ num_spatial_tokens = patches_layout[1] * patches_layout[2]
94
+ pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1)
95
+ # interpolate embedding for zeroth frame
96
+ pos_embed = interpolate_pos_encoding_2d(
97
+ npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0)
98
+ )
99
+ else:
100
+ raise ValueError("This type of interpolation isn't implemented")
101
+
102
+ return torch.cat((class_emb, pos_embed), dim=1)
103
+
104
+
105
+ def _get_pos_embedding(
106
+ npatch_per_img,
107
+ pos_embed,
108
+ patches_layout,
109
+ input_shape,
110
+ first_patch_idx=1,
111
+ ):
112
+ pos_embed = interpolate_pos_encoding(
113
+ npatch_per_img,
114
+ pos_embed,
115
+ patches_layout,
116
+ input_shape=input_shape,
117
+ first_patch_idx=first_patch_idx,
118
+ )
119
+ return pos_embed
120
+
121
+
122
+ class PatchEmbedGeneric(nn.Module):
123
+ """
124
+ PatchEmbed from Hydra
125
+ """
126
+
127
+ def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
128
+ super().__init__()
129
+
130
+ if len(proj_stem) > 1:
131
+ self.proj = nn.Sequential(*proj_stem)
132
+ else:
133
+ # Special case to be able to load pre-trained models that were
134
+ # trained with a standard stem
135
+ self.proj = proj_stem[0]
136
+ self.norm_layer = norm_layer
137
+
138
+ def get_patch_layout(self, img_size):
139
+ with torch.no_grad():
140
+ dummy_img = torch.zeros(
141
+ [
142
+ 1,
143
+ ]
144
+ + img_size
145
+ )
146
+ dummy_out = self.proj(dummy_img)
147
+ embed_dim = dummy_out.shape[1]
148
+ patches_layout = tuple(dummy_out.shape[2:])
149
+ num_patches = np.prod(patches_layout)
150
+ return patches_layout, num_patches, embed_dim
151
+
152
+ def forward(self, x):
153
+ x = self.proj(x)
154
+ # B C (T) H W -> B (T)HW C
155
+ x = x.flatten(2).transpose(1, 2)
156
+ if self.norm_layer is not None:
157
+ x = self.norm_layer(x)
158
+ return x
159
+
160
+
161
+ class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
162
+ def __init__(
163
+ self,
164
+ patches_layout: List,
165
+ num_patches: int,
166
+ num_cls_tokens: int,
167
+ embed_dim: int,
168
+ learnable: bool,
169
+ ) -> None:
170
+ super().__init__()
171
+ self.num_cls_tokens = num_cls_tokens
172
+ self.patches_layout = patches_layout
173
+ self.num_patches = num_patches
174
+ self.num_tokens = num_cls_tokens + num_patches
175
+ self.learnable = learnable
176
+ if self.learnable:
177
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
178
+ trunc_normal_(self.pos_embed, std=0.02)
179
+ else:
180
+ self.register_buffer(
181
+ "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim)
182
+ )
183
+
184
+ def get_pos_embedding(self, vision_input, all_vision_tokens):
185
+ input_shape = vision_input.shape
186
+ pos_embed = _get_pos_embedding(
187
+ all_vision_tokens.size(1) - self.num_cls_tokens,
188
+ pos_embed=self.pos_embed,
189
+ patches_layout=self.patches_layout,
190
+ input_shape=input_shape,
191
+ first_patch_idx=self.num_cls_tokens,
192
+ )
193
+ return pos_embed
194
+
195
+
196
+ class RGBDTPreprocessor(VerboseNNModule):
197
+ def __init__(
198
+ self,
199
+ rgbt_stem: PatchEmbedGeneric,
200
+ depth_stem: PatchEmbedGeneric,
201
+ img_size: List = (3, 224, 224),
202
+ num_cls_tokens: int = 1,
203
+ pos_embed_fn: Callable = None,
204
+ use_type_embed: bool = False,
205
+ init_param_style: str = "openclip",
206
+ ) -> None:
207
+ super().__init__()
208
+ stem = rgbt_stem if rgbt_stem is not None else depth_stem
209
+ (
210
+ self.patches_layout,
211
+ self.num_patches,
212
+ self.embed_dim,
213
+ ) = stem.get_patch_layout(img_size)
214
+ self.rgbt_stem = rgbt_stem
215
+ self.depth_stem = depth_stem
216
+ self.use_pos_embed = pos_embed_fn is not None
217
+ self.use_type_embed = use_type_embed
218
+ self.num_cls_tokens = num_cls_tokens
219
+
220
+ if self.use_pos_embed:
221
+ self.pos_embedding_helper = pos_embed_fn(
222
+ patches_layout=self.patches_layout,
223
+ num_cls_tokens=num_cls_tokens,
224
+ num_patches=self.num_patches,
225
+ embed_dim=self.embed_dim,
226
+ )
227
+ if self.num_cls_tokens > 0:
228
+ self.cls_token = nn.Parameter(
229
+ torch.zeros(1, self.num_cls_tokens, self.embed_dim)
230
+ )
231
+ if self.use_type_embed:
232
+ self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
233
+
234
+ self.init_parameters(init_param_style)
235
+
236
+ @torch.no_grad()
237
+ def init_parameters(self, init_param_style):
238
+ if init_param_style == "openclip":
239
+ # OpenCLIP style initialization
240
+ scale = self.embed_dim**-0.5
241
+ if self.use_pos_embed:
242
+ nn.init.normal_(self.pos_embedding_helper.pos_embed)
243
+ self.pos_embedding_helper.pos_embed *= scale
244
+
245
+ if self.num_cls_tokens > 0:
246
+ nn.init.normal_(self.cls_token)
247
+ self.cls_token *= scale
248
+ elif init_param_style == "vit":
249
+ self.cls_token.data.fill_(0)
250
+ else:
251
+ raise ValueError(f"Unknown init {init_param_style}")
252
+
253
+ if self.use_type_embed:
254
+ nn.init.normal_(self.type_embed)
255
+
256
+ def tokenize_input_and_cls_pos(self, input, stem, mask):
257
+ # tokens is of shape B x L x D
258
+ tokens = stem(input)
259
+ assert tokens.ndim == 3
260
+ assert tokens.shape[2] == self.embed_dim
261
+ B = tokens.shape[0]
262
+ if self.num_cls_tokens > 0:
263
+ class_tokens = self.cls_token.expand(
264
+ B, -1, -1
265
+ ) # stole class_tokens impl from Phil Wang, thanks
266
+ tokens = torch.cat((class_tokens, tokens), dim=1)
267
+ if self.use_pos_embed:
268
+ pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
269
+ tokens = tokens + pos_embed
270
+ if self.use_type_embed:
271
+ tokens = tokens + self.type_embed.expand(B, -1, -1)
272
+ return tokens
273
+
274
+ def forward(self, vision=None, depth=None, patch_mask=None):
275
+ if patch_mask is not None:
276
+ raise NotImplementedError()
277
+
278
+ if vision is not None:
279
+ vision_tokens = self.tokenize_input_and_cls_pos(
280
+ vision, self.rgbt_stem, patch_mask
281
+ )
282
+
283
+ if depth is not None:
284
+ depth_tokens = self.tokenize_input_and_cls_pos(
285
+ depth, self.depth_stem, patch_mask
286
+ )
287
+
288
+ # aggregate tokens
289
+ if vision is not None and depth is not None:
290
+ final_tokens = vision_tokens + depth_tokens
291
+ else:
292
+ final_tokens = vision_tokens if vision is not None else depth_tokens
293
+ return_dict = {
294
+ "trunk": {
295
+ "tokens": final_tokens,
296
+ },
297
+ "head": {},
298
+ }
299
+ return return_dict
300
+
301
+
302
+ class AudioPreprocessor(RGBDTPreprocessor):
303
+ def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
304
+ super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)
305
+
306
+ def forward(self, audio=None):
307
+ return super().forward(vision=audio)
308
+
309
+
310
+ class ThermalPreprocessor(RGBDTPreprocessor):
311
+ def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
312
+ super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)
313
+
314
+ def forward(self, thermal=None):
315
+ return super().forward(vision=thermal)
316
+
317
+
318
+ def build_causal_attention_mask(context_length):
319
+ # lazily create causal attention mask, with full attention between the vision tokens
320
+ # pytorch uses additive attention mask; fill with -inf
321
+ mask = torch.empty(context_length, context_length, requires_grad=False)
322
+ mask.fill_(float("-inf"))
323
+ mask.triu_(1) # zero out the lower diagonal
324
+ return mask
325
+
326
+
327
+ class TextPreprocessor(VerboseNNModule):
328
+ def __init__(
329
+ self,
330
+ vocab_size: int,
331
+ context_length: int,
332
+ embed_dim: int,
333
+ causal_masking: bool,
334
+ supply_seq_len_to_head: bool = True,
335
+ num_cls_tokens: int = 0,
336
+ init_param_style: str = "openclip",
337
+ ) -> None:
338
+ super().__init__()
339
+ self.vocab_size = vocab_size
340
+ self.context_length = context_length
341
+ self.token_embedding = nn.Embedding(vocab_size, embed_dim)
342
+ self.pos_embed = nn.Parameter(
343
+ torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
344
+ )
345
+ self.causal_masking = causal_masking
346
+ if self.causal_masking:
347
+ mask = build_causal_attention_mask(self.context_length)
348
+ # register the mask as a buffer so it can be moved to the right device
349
+ self.register_buffer("mask", mask)
350
+
351
+ self.supply_seq_len_to_head = supply_seq_len_to_head
352
+ self.num_cls_tokens = num_cls_tokens
353
+ self.embed_dim = embed_dim
354
+ if num_cls_tokens > 0:
355
+ assert self.causal_masking is False, "Masking + CLS token isn't implemented"
356
+ self.cls_token = nn.Parameter(
357
+ torch.zeros(1, self.num_cls_tokens, embed_dim)
358
+ )
359
+
360
+ self.init_parameters(init_param_style)
361
+
362
+ @torch.no_grad()
363
+ def init_parameters(self, init_param_style="openclip"):
364
+ # OpenCLIP style initialization
365
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
366
+ nn.init.normal_(self.pos_embed, std=0.01)
367
+
368
+ if init_param_style == "openclip":
369
+ # OpenCLIP style initialization
370
+ scale = self.embed_dim**-0.5
371
+ if self.num_cls_tokens > 0:
372
+ nn.init.normal_(self.cls_token)
373
+ self.cls_token *= scale
374
+ elif init_param_style == "vit":
375
+ self.cls_token.data.fill_(0)
376
+ else:
377
+ raise ValueError(f"Unknown init {init_param_style}")
378
+
379
+ def forward(self, text):
380
+ # text tokens are of shape B x L x D
381
+ text_tokens = self.token_embedding(text)
382
+ # concat CLS tokens if any
383
+ if self.num_cls_tokens > 0:
384
+ B = text_tokens.shape[0]
385
+ class_tokens = self.cls_token.expand(
386
+ B, -1, -1
387
+ ) # stole class_tokens impl from Phil Wang, thanks
388
+ text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
389
+ text_tokens = text_tokens + self.pos_embed
390
+ return_dict = {
391
+ "trunk": {
392
+ "tokens": text_tokens,
393
+ },
394
+ "head": {},
395
+ }
396
+ # Compute sequence length after adding CLS tokens
397
+ if self.supply_seq_len_to_head:
398
+ text_lengths = text.argmax(dim=-1)
399
+ return_dict["head"] = {
400
+ "seq_len": text_lengths,
401
+ }
402
+ if self.causal_masking:
403
+ return_dict["trunk"].update({"attn_mask": self.mask})
404
+ return return_dict
405
+
406
+
407
+ class Im2Video(nn.Module):
408
+ """Convert an image into a trivial video."""
409
+
410
+ def __init__(self, time_dim=2):
411
+ super().__init__()
412
+ self.time_dim = time_dim
413
+
414
+ def forward(self, x):
415
+ if x.ndim == 4:
416
+ # B, C, H, W -> B, C, T, H, W
417
+ return x.unsqueeze(self.time_dim)
418
+ elif x.ndim == 5:
419
+ return x
420
+ else:
421
+ raise ValueError(f"Dimension incorrect {x.shape}")
422
+
423
+
424
+ class PadIm2Video(Im2Video):
425
+ def __init__(self, ntimes, pad_type, time_dim=2):
426
+ super().__init__(time_dim=time_dim)
427
+ assert ntimes > 0
428
+ assert pad_type in ["zero", "repeat"]
429
+ self.ntimes = ntimes
430
+ self.pad_type = pad_type
431
+
432
+ def forward(self, x):
433
+ x = super().forward(x)
434
+ if x.shape[self.time_dim] == 1:
435
+ if self.pad_type == "repeat":
436
+ new_shape = [1] * len(x.shape)
437
+ new_shape[self.time_dim] = self.ntimes
438
+ x = x.repeat(new_shape)
439
+ elif self.pad_type == "zero":
440
+ padarg = [0, 0] * len(x.shape)
441
+ padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim]
442
+ x = nn.functional.pad(x, padarg)
443
+ return x
444
+
445
+
446
+ # Modified from github.com/openai/CLIP
447
+ @lru_cache()
448
+ def bytes_to_unicode():
449
+ """
450
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
451
+ The reversible bpe codes work on unicode strings.
452
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
453
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
454
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
455
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
456
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
457
+ """
458
+ bs = (
459
+ list(range(ord("!"), ord("~") + 1))
460
+ + list(range(ord("¡"), ord("¬") + 1))
461
+ + list(range(ord("®"), ord("ÿ") + 1))
462
+ )
463
+ cs = bs[:]
464
+ n = 0
465
+ for b in range(2**8):
466
+ if b not in bs:
467
+ bs.append(b)
468
+ cs.append(2**8 + n)
469
+ n += 1
470
+ cs = [chr(n) for n in cs]
471
+ return dict(zip(bs, cs))
472
+
473
+
474
+ def get_pairs(word):
475
+ """Return set of symbol pairs in a word.
476
+ Word is represented as tuple of symbols (symbols being variable-length strings).
477
+ """
478
+ pairs = set()
479
+ prev_char = word[0]
480
+ for char in word[1:]:
481
+ pairs.add((prev_char, char))
482
+ prev_char = char
483
+ return pairs
484
+
485
+
486
+ def basic_clean(text):
487
+ text = ftfy.fix_text(text)
488
+ text = html.unescape(html.unescape(text))
489
+ return text.strip()
490
+
491
+
492
+ def whitespace_clean(text):
493
+ text = re.sub(r"\s+", " ", text)
494
+ text = text.strip()
495
+ return text
496
+
497
+
498
+ class SimpleTokenizer(object):
499
+ def __init__(self, bpe_path: str, context_length=77):
500
+ self.byte_encoder = bytes_to_unicode()
501
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
502
+
503
+ with g_pathmgr.open(bpe_path, "rb") as fh:
504
+ bpe_bytes = io.BytesIO(fh.read())
505
+ merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
506
+ merges = merges[1 : 49152 - 256 - 2 + 1]
507
+ merges = [tuple(merge.split()) for merge in merges]
508
+ vocab = list(bytes_to_unicode().values())
509
+ vocab = vocab + [v + "</w>" for v in vocab]
510
+ for merge in merges:
511
+ vocab.append("".join(merge))
512
+ vocab.extend(["<|startoftext|>", "<|endoftext|>"])
513
+ self.encoder = dict(zip(vocab, range(len(vocab))))
514
+ self.decoder = {v: k for k, v in self.encoder.items()}
515
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
516
+ self.cache = {
517
+ "<|startoftext|>": "<|startoftext|>",
518
+ "<|endoftext|>": "<|endoftext|>",
519
+ }
520
+ self.pat = re.compile(
521
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
522
+ re.IGNORECASE,
523
+ )
524
+ self.context_length = context_length
525
+
526
+ def bpe(self, token):
527
+ if token in self.cache:
528
+ return self.cache[token]
529
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
530
+ pairs = get_pairs(word)
531
+
532
+ if not pairs:
533
+ return token + "</w>"
534
+
535
+ while True:
536
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
537
+ if bigram not in self.bpe_ranks:
538
+ break
539
+ first, second = bigram
540
+ new_word = []
541
+ i = 0
542
+ while i < len(word):
543
+ try:
544
+ j = word.index(first, i)
545
+ new_word.extend(word[i:j])
546
+ i = j
547
+ except:
548
+ new_word.extend(word[i:])
549
+ break
550
+
551
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
552
+ new_word.append(first + second)
553
+ i += 2
554
+ else:
555
+ new_word.append(word[i])
556
+ i += 1
557
+ new_word = tuple(new_word)
558
+             word = new_word
+             if len(word) == 1:
+                 break
+             else:
+                 pairs = get_pairs(word)
+         word = " ".join(word)
+         self.cache[token] = word
+         return word
+
+     def encode(self, text):
+         bpe_tokens = []
+         text = whitespace_clean(basic_clean(text)).lower()
+         for token in re.findall(self.pat, text):
+             token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+             bpe_tokens.extend(
+                 self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
+             )
+         return bpe_tokens
+
+     def decode(self, tokens):
+         text = "".join([self.decoder[token] for token in tokens])
+         text = (
+             bytearray([self.byte_decoder[c] for c in text])
+             .decode("utf-8", errors="replace")
+             .replace("</w>", " ")
+         )
+         return text
+
+     def __call__(self, texts, context_length=None):
+         if not context_length:
+             context_length = self.context_length
+
+         if isinstance(texts, str):
+             texts = [texts]
+
+         sot_token = self.encoder["<|startoftext|>"]
+         eot_token = self.encoder["<|endoftext|>"]
+         all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
+         result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+         for i, tokens in enumerate(all_tokens):
+             tokens = tokens[:context_length]
+             result[i, : len(tokens)] = torch.tensor(tokens)
+
+         if len(result) == 1:
+             return result[0]
+         return result
+
+
+ class IMUPreprocessor(VerboseNNModule):
+     def __init__(
+         self,
+         kernel_size: int,
+         imu_stem: PatchEmbedGeneric,
+         embed_dim: int,
+         img_size: List = (6, 2000),
+         num_cls_tokens: int = 1,
+         pos_embed_fn: Callable = None,
+         init_param_style: str = "openclip",
+     ) -> None:
+         super().__init__()
+         stem = imu_stem
+         self.imu_stem = imu_stem
+         self.embed_dim = embed_dim
+         self.use_pos_embed = pos_embed_fn is not None
+         self.num_cls_tokens = num_cls_tokens
+         self.kernel_size = kernel_size
+         self.pos_embed = nn.Parameter(
+             torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
+         )
+
+         if self.num_cls_tokens > 0:
+             self.cls_token = nn.Parameter(
+                 torch.zeros(1, self.num_cls_tokens, self.embed_dim)
+             )
+
+         self.init_parameters(init_param_style)
+
+     @torch.no_grad()
+     def init_parameters(self, init_param_style):
+         nn.init.normal_(self.pos_embed, std=0.01)
+
+         if init_param_style == "openclip":
+             # OpenCLIP style initialization
+             scale = self.embed_dim**-0.5
+
+             if self.num_cls_tokens > 0:
+                 nn.init.normal_(self.cls_token)
+                 self.cls_token *= scale
+         elif init_param_style == "vit":
+             self.cls_token.data.fill_(0)
+         else:
+             raise ValueError(f"Unknown init {init_param_style}")
+
+     def tokenize_input_and_cls_pos(self, input, stem):
+         # tokens is of shape B x L x D
+         tokens = stem.norm_layer(stem.proj(input))
+         assert tokens.ndim == 3
+         assert tokens.shape[2] == self.embed_dim
+         B = tokens.shape[0]
+         if self.num_cls_tokens > 0:
+             class_tokens = self.cls_token.expand(
+                 B, -1, -1
+             )  # stole class_tokens impl from Phil Wang, thanks
+             tokens = torch.cat((class_tokens, tokens), dim=1)
+         if self.use_pos_embed:
+             tokens = tokens + self.pos_embed
+         return tokens
+
+     def forward(self, imu):
+         # Patchify
+         imu = imu.unfold(
+             -1,
+             self.kernel_size,
+             self.kernel_size,
+         ).permute(0, 2, 1, 3)
+         imu = imu.reshape(imu.size(0), imu.size(1), -1)
+
+         imu_tokens = self.tokenize_input_and_cls_pos(
+             imu,
+             self.imu_stem,
+         )
+
+         return_dict = {
+             "trunk": {
+                 "tokens": imu_tokens,
+             },
+             "head": {},
+         }
+         return return_dict
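
For context, a minimal usage sketch of the tokenizer defined above. The constructor call and context length are assumptions based on how data.py builds the tokenizer from BPE_PATH; treat the values as illustrative, not canonical.

    # Hypothetical usage sketch; the bpe path mirrors BPE_PATH in data.py.
    from models.multimodal_preprocessors import SimpleTokenizer

    tokenizer = SimpleTokenizer(bpe_path="bpe/bpe_simple_vocab_16e6.txt.gz")
    batch = tokenizer(["a dog barking", "a car engine"])  # LongTensor, one padded row per string
    single = tokenizer("a bird chirping")                 # 1-D LongTensor (a single string is squeezed)

As the __call__ method shows, each string is wrapped in <|startoftext|> / <|endoftext|> tokens, truncated to the context length, and right-padded with zeros.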
models/transformer.py ADDED
@@ -0,0 +1,284 @@
+ #!/usr/bin/env python3
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Code modified from
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ;
+ # https://github.com/facebookresearch/deit/blob/main/models.py
+ # and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py
+
+
+ import copy
+ import fnmatch
+ import logging
+ from functools import partial
+ from typing import Callable, List
+
+ import torch
+ import torch.nn as nn
+ import torch.utils.checkpoint as checkpoint
+
+ from timm.models.layers import DropPath, trunc_normal_
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dim,
+         num_heads=8,
+         qkv_bias=False,
+         qk_scale=None,
+         attn_drop=0.0,
+         proj_drop=0.0,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         # NOTE scale factor was wrong in my original version,
+         # can set manually to be compat with prev weights
+         self.scale = qk_scale or head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x):
+         B, N, C = x.shape
+         qkv = (
+             self.qkv(x)
+             .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+             .permute(2, 0, 3, 1, 4)
+         )
+         q, k, v = (
+             qkv[0],
+             qkv[1],
+             qkv[2],
+         )  # make torchscript happy (cannot use tensor as tuple)
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class Mlp(nn.Module):
+     def __init__(
+         self,
+         in_features,
+         hidden_features=None,
+         out_features=None,
+         act_layer=nn.GELU,
+         drop=0.0,
+     ):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class MultiheadAttention(nn.MultiheadAttention):
+     def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
+         return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
+
+
+ class ViTAttention(Attention):
+     def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
+         assert attn_mask is None
+         return super().forward(x)
+
+
+ class BlockWithMasking(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         attn_target: Callable,
+         mlp_ratio: int = 4,
+         act_layer: Callable = nn.GELU,
+         norm_layer: Callable = nn.LayerNorm,
+         ffn_dropout_rate: float = 0.0,
+         drop_path: float = 0.0,
+         layer_scale_type: str = None,
+         layer_scale_init_value: float = 1e-4,
+     ):
+         super().__init__()
+
+         assert not isinstance(
+             attn_target, nn.Module
+         ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
+         self.attn = attn_target()
+         if drop_path > 0.0:
+             self.drop_path = DropPath(drop_path)
+         else:
+             self.drop_path = nn.Identity()
+         self.norm_1 = norm_layer(dim)
+         mlp_hidden_dim = int(mlp_ratio * dim)
+         self.mlp = Mlp(
+             in_features=dim,
+             hidden_features=mlp_hidden_dim,
+             act_layer=act_layer,
+             drop=ffn_dropout_rate,
+         )
+         self.norm_2 = norm_layer(dim)
+         self.layer_scale_type = layer_scale_type
+         if self.layer_scale_type is not None:
+             assert self.layer_scale_type in [
+                 "per_channel",
+                 "scalar",
+             ], f"Found Layer scale type {self.layer_scale_type}"
+             if self.layer_scale_type == "per_channel":
+                 # one gamma value per channel
+                 gamma_shape = [1, 1, dim]
+             elif self.layer_scale_type == "scalar":
+                 # single gamma value for all channels
+                 gamma_shape = [1, 1, 1]
+             # two gammas: for each part of the fwd in the encoder
+             self.layer_scale_gamma1 = nn.Parameter(
+                 torch.ones(size=gamma_shape) * layer_scale_init_value,
+                 requires_grad=True,
+             )
+             self.layer_scale_gamma2 = nn.Parameter(
+                 torch.ones(size=gamma_shape) * layer_scale_init_value,
+                 requires_grad=True,
+             )
+
+     def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
+         if self.layer_scale_type is None:
+             x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask))
+             x = x + self.drop_path(self.mlp(self.norm_2(x)))
+         else:
+             x = (
+                 x
+                 + self.drop_path(self.attn(self.norm_1(x), attn_mask))
+                 * self.layer_scale_gamma1
+             )
+             x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2
+         return x
+
+
+ _LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)
+
+
+ class SimpleTransformer(nn.Module):
+     def __init__(
+         self,
+         attn_target: Callable,
+         embed_dim: int,
+         num_blocks: int,
+         block: Callable = BlockWithMasking,
+         pre_transformer_layer: Callable = None,
+         post_transformer_layer: Callable = None,
+         drop_path_rate: float = 0.0,
+         drop_path_type: str = "progressive",
+         norm_layer: Callable = _LAYER_NORM,
+         mlp_ratio: int = 4,
+         ffn_dropout_rate: float = 0.0,
+         layer_scale_type: str = None,  # from cait; possible values are None, "per_channel", "scalar"
+         layer_scale_init_value: float = 1e-4,  # from cait; float
+         weight_init_style: str = "jax",  # possible values jax or pytorch
+     ):
+         """
+         Simple Transformer with the following features
+         1. Supports masked attention
+         2. Supports DropPath
+         3. Supports LayerScale
+         4. Supports Dropout in Attention and FFN
+         5. Makes few assumptions about the input except that it is a Tensor
+         """
+         super().__init__()
+         self.pre_transformer_layer = pre_transformer_layer
+         if drop_path_type == "progressive":
+             dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
+         elif drop_path_type == "uniform":
+             dpr = [drop_path_rate for i in range(num_blocks)]
+         else:
+             raise ValueError(f"Unknown drop_path_type: {drop_path_type}")
+
+         self.blocks = nn.Sequential(
+             *[
+                 block(
+                     dim=embed_dim,
+                     attn_target=attn_target,
+                     mlp_ratio=mlp_ratio,
+                     ffn_dropout_rate=ffn_dropout_rate,
+                     drop_path=dpr[i],
+                     norm_layer=norm_layer,
+                     layer_scale_type=layer_scale_type,
+                     layer_scale_init_value=layer_scale_init_value,
+                 )
+                 for i in range(num_blocks)
+             ]
+         )
+         self.post_transformer_layer = post_transformer_layer
+         self.weight_init_style = weight_init_style
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             if self.weight_init_style == "jax":
+                 # Based on MAE and official Jax ViT implementation
+                 torch.nn.init.xavier_uniform_(m.weight)
+             elif self.weight_init_style == "pytorch":
+                 # PyTorch ViT uses trunc_normal_
+                 trunc_normal_(m.weight, std=0.02)
+
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, (nn.LayerNorm)):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def forward(
+         self,
+         tokens: torch.Tensor,
+         attn_mask: torch.Tensor = None,
+         use_checkpoint: bool = False,
+         checkpoint_every_n: int = 1,
+         checkpoint_blk_ids: List[int] = None,
+     ):
+         """
+         Inputs
+         - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
+         - attn_mask: mask of shape L x L
+
+         Output
+         - x: data of shape N x L x D (or L x N x D depending on the attention implementation)
+         """
+         if self.pre_transformer_layer:
+             tokens = self.pre_transformer_layer(tokens)
+         if use_checkpoint and checkpoint_blk_ids is None:
+             checkpoint_blk_ids = [
+                 blk_id
+                 for blk_id in range(len(self.blocks))
+                 if blk_id % checkpoint_every_n == 0
+             ]
+         if checkpoint_blk_ids:
+             checkpoint_blk_ids = set(checkpoint_blk_ids)
+         for blk_id, blk in enumerate(self.blocks):
+             if use_checkpoint and blk_id in checkpoint_blk_ids:
+                 tokens = checkpoint.checkpoint(
+                     blk, tokens, attn_mask, use_reentrant=False
+                 )
+             else:
+                 tokens = blk(tokens, attn_mask=attn_mask)
+         if self.post_transformer_layer:
+             tokens = self.post_transformer_layer(tokens)
+         return tokens
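
A minimal sketch of how the pieces above can be wired together. Every dimension, the import path, and the choice of nn.MultiheadAttention via functools.partial are assumptions for illustration only, not the configuration used elsewhere in this repo.

    # Hypothetical wiring: attn_target must be a callable (see the assert in
    # BlockWithMasking), so each block constructs its own attention module.
    from functools import partial

    import torch

    from models.transformer import MultiheadAttention, SimpleTransformer

    attn_target = partial(MultiheadAttention, embed_dim=512, num_heads=8, batch_first=True)
    trunk = SimpleTransformer(attn_target=attn_target, embed_dim=512, num_blocks=2)

    tokens = torch.randn(2, 16, 512)  # N x L x D, matching batch_first=True
    out = trunk(tokens)               # same shape; attn_mask and use_checkpoint are optional

Passing a partial rather than an instantiated module is what keeps attention weights unshared across blocks, which is exactly what the assert in BlockWithMasking guards against.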
train.py ADDED
@@ -0,0 +1,216 @@
+ import data
+ import torch
+ from models import imagebind_model
+ from models.imagebind_model import ModalityType
+ import torch.nn as nn
+ from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
+ from extract.getim import load_image
+ import torch.optim as optim
+ import os
+ from torchvision import transforms
+ from image2vidimg import cobtwoten, cobtwoten256
+
+ # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
+ device = torch.device("cuda")
+ # import matplotlib.pyplot as plt
+ # import torch.nn.functional as F
+ # import cv2
+
+ torch.cuda.empty_cache()
+ transform = transforms.Compose([
+     transforms.ToTensor(),  # converts a numpy array or PIL.Image to a (C, H, W) tensor and divides by 255 to normalize into [0, 1.0]
+ ])  # mean and variance from ImageNet
+ unloader = transforms.ToPILImage()
+
+ # def imshow(tensor, title=None):
+ #
+ #     tensor = tensor.permute(1, 2, 0)
+ #     print(tensor.shape)
+ #     cv2.imshow('image:', tensor.cpu().numpy())
+ #     # keep the image window from closing
+ #     cv2.waitKey(0)
+ #     # plt.imshow(img_pil)
+ #     # if title is not None:
+ #     #     plt.title(title)
+ #     # plt.pause(0.001)  # pause a bit so that plots are updated
+
+ def imagebind_out(audio_paths, model):
+     # Load data
+     inputs = {
+         ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
+     }
+
+     with torch.no_grad():
+         embeddings = model(inputs)
+
+     return embeddings
+
+ class encode_audio(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         # self.link1 = nn.Linear(1024, 768)
+         self.link2 = nn.Linear(1024, 343)
+         # self.link3 = nn.Linear(1024, 768)
+
+     def forward(self, embeddings):
+         l1 = embeddings                            # (B, 1, 1024) ImageBind audio embedding
+         l2 = self.link2(embeddings)                # (B, 1, 343)
+         # l3 = self.link3(embeddings)
+         l3 = torch.matmul(l2.transpose(1, 2), l1)  # (B, 343, 1024)
+
+         return torch.cat([l1, l3], dim=1)          # (B, 344, 1024), matching max_text_len=344 below
+
+
+ # os.listdir() gets the file names in the folder and returns them as a list
+ def getAllFiles(targetDir):
+     listFiles = os.listdir(targetDir)
+     return listFiles
+
+ # unets for imagen
+ unet1 = Unet3D(max_text_len=344, text_embed_dim=1024, dim=64, dim_mults=(1, 2, 4, 8)).to(device)
+ unet2 = Unet3D(max_text_len=344, text_embed_dim=1024, dim=128, dim_mults=(1, 2, 4, 8)).to(device)
+ # unet3 = Unet3D(dim = 256, dim_mults = (1, 2, 4, 8)).cuda()
+ # unet1 = NullUnet()  # add a placeholder "null" unet for the base unet
+
+ imagen = ElucidatedImagen(
+     text_embed_dim = 1024,
+     unets = (unet1, unet2),
+     image_sizes = (64, 128),
+     random_crop_sizes = (None, 64),
+     temporal_downsample_factor = (2, 1),  # in this example, the first unet would receive the video temporally downsampled by 2x
+     num_sample_steps = 10,
+     cond_drop_prob = 0.1,
+     sigma_min = 0.002,      # min noise level
+     sigma_max = (80, 160),  # max noise level, double the max noise level for upsampler
+     sigma_data = 0.5,       # standard deviation of data distribution
+     rho = 7,                # controls the sampling schedule
+     P_mean = -1.2,          # mean of log-normal distribution from which noise is drawn for training
+     P_std = 1.2,            # standard deviation of log-normal distribution from which noise is drawn for training
+     S_churn = 80,           # parameters for stochastic sampling - depends on dataset, Table 5 in paper
+     S_tmin = 0.05,
+     S_tmax = 50,
+     S_noise = 1.003,
+ ).to(device)
+
+ trainer = ImagenTrainer(imagen)
+ # trainer.to(device)
+ # trainer.load("./checkpoint.pt")
+ trainer = trainer.to(device)
+ # Instantiate model
+
+ # device_ids = [0, 1]
+ model_imageb = imagebind_model.imagebind_huge(pretrained=True)
+ model_imageb = model_imageb.to(device)
+ model_imageb.eval()
+ # model_imageb = model_imageb.cuda(device=device_ids)
+ # model_imageb.to(device)
+
+ epo = 31
+ p = 1
+ files = getAllFiles("./extract/audio")
+
+ outloss = 0
+ model1 = (encode_audio()).to(device)
+ # model1.load_state_dict(torch.load("wlc.pt").state_dict())
+ optimizer = optim.Adam(model1.parameters(), lr=1e-5,
+                        betas=(0.9, 0.999), eps=1e-08, weight_decay=0., amsgrad=True)
+ # model1.eval()
+ model1.train()
+ torch.cuda.empty_cache()
+
+ for k in range(epo):
+     for nm in range(0, len(files) + 1 - p, p):
+         # for i in (1, 2):
+         file_ext0 = os.path.splitext(files[nm])
+         front0, ext0 = file_ext0
+         audio_pat = []
+         audio_pat.append("./extract/audio/" + str(front0) + ".wav")
+         # fcontents = load_image("./extract/image/0.jpg", transform=None, shape=[256, 128])
+         fcontent = cobtwoten("./extract/image/" + str(front0) + ".jpg")
+         # print(fcontent.shape)
+         # fcontent = load_image("./extract/image/" + str(front0) + ".jpg", transform, shape=[256, 256])
+         for ni in range(1, p):
+             file_ext = os.path.splitext(files[nm + ni])
+             front, ext = file_ext
+             # content = load_image("./extract/image/" + str(front) + ".jpg", transform, shape=[256, 256])
+             content = cobtwoten("./extract/image/" + str(front) + ".jpg")
+             fcontent = torch.cat((fcontent, content), -5)
+             audio_pat.append("./extract/audio/" + str(front) + ".wav")
+         # imageb = torch.LongTensor(imageb_out["audio"])
+
+         imageb_out = imagebind_out(audio_pat, model_imageb)
+         fmusic = model1(imageb_out["audio"].unsqueeze(1))  # (5, 1, 1024) -> (5, 344, 1024)
+         # fmusic = model1(imageb_out["audio"].unsqueeze(1).cuda())  # (5, 1, 1024) -> (5, 344, 1024)
+         # print(fmusic)
+         # print(fmusic.shape)
+         fmusic = fmusic.to(device)
+         fcontent = fcontent.to(device)
+         loss = trainer(fcontent, text_embeds=fmusic, unet_number=2, ignore_time=False, max_batch_size=p)
+         trainer.update(unet_number=2)
+         optimizer.step()
+         # print(optimizer.state)
+         optimizer.zero_grad()
+         print(loss)
+         outloss = outloss + loss
+         # print("unet" + str(i) + " " + str(loss))
+
+     print("epoch" + str(k) + " " + " loss: " + str(outloss))
+
+     outloss = 0
+     if k % 3 == 2:
+         torch.save(model1, "wlc.pt")
+         trainer.save('./checkpoint.pt')
+
+
+
+
+
+ # text_list = ["A dog.", "A car", "A bird"]
+ # image_paths = [".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"]
+
+
+ #
+ #
+ #
+ #
+ # print(
+ #     "Vision x Text: ",
+ #     torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1),
+ # )
+ # print(
+ #     "Audio x Text: ",
+ #     torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1),
+ # )
+ # print(
+ #     "Vision x Audio: ",
+ #     torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=-1),
+ # )
+ #
+ #
+ #
+ # print(embeddings['audio'].shape)
+ # print(embeddings[ModalityType.AUDIO].shape)
+ # print(embeddings[ModalityType.VISION].shape)
+
+
+ # Expected output:
+ #
+ # Vision x Text:
+ # tensor([[9.9761e-01, 2.3694e-03, 1.8612e-05],
+ #         [3.3836e-05, 9.9994e-01, 2.4118e-05],
+ #         [4.7997e-05, 1.3496e-02, 9.8646e-01]])
+ #
+ # Audio x Text:
+ # tensor([[1., 0., 0.],
+ #         [0., 1., 0.],
+ #         [0., 0., 1.]])
+ #
+ # Vision x Audio:
+ # tensor([[0.8070, 0.1088, 0.0842],
+ #         [0.1036, 0.7884, 0.1079],
+ #         [0.0018, 0.0022, 0.9960]])
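
For completeness, a short resume sketch for a later run. It simply mirrors the commented-out load calls earlier in this script (torch.load on wlc.pt and ImagenTrainer.load on ./checkpoint.pt) and assumes the same objects (encode_audio, imagen, device) are in scope; treat it as an assumption rather than a tested recipe.

    # Hypothetical resume path, mirroring the commented-out lines above.
    model1 = encode_audio().to(device)
    model1.load_state_dict(torch.load("wlc.pt").state_dict())  # wlc.pt stores the whole module, not just a state_dict
    model1.train()

    trainer = ImagenTrainer(imagen).to(device)
    trainer.load("./checkpoint.pt")  # restores the unets saved by trainer.save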