Lin Z committed
Commit d6d7648 · 1 Parent(s): 0dc0933

init commit

Files changed (42)
  1. .checkpoints/imagebind_huge.pth +3 -0
  2. app.py +140 -0
  3. assets/.DS_Store +0 -0
  4. assets/lion_and_gun.png +0 -0
  5. assets/lions_roaring.wav +0 -0
  6. assets/machine_gun_shooting.wav +0 -0
  7. audio_encoder.py +124 -0
  8. checkpoints/audio-cond_animation/avsync15_audio-cond_cfg/ckpts/checkpoint-37000/modules/audio_encoder/config.json +6 -0
  9. checkpoints/audio-cond_animation/avsync15_audio-cond_cfg/ckpts/checkpoint-37000/modules/audio_encoder/diffusion_pytorch_model.safetensors +3 -0
  10. checkpoints/audio-cond_animation/avsync15_audio-cond_cfg/ckpts/checkpoint-37000/modules/unet/config.json +61 -0
  11. checkpoints/audio-cond_animation/avsync15_audio-cond_cfg/ckpts/checkpoint-37000/modules/unet/diffusion_pytorch_model.safetensors +3 -0
  12. datasets/AVSync15/class_clip_text_encodings_stable-diffusion-v1-5.pt +3 -0
  13. ff_spatio_audio_temp_transformer_3d.py +374 -0
  14. ff_spatio_temp_resnet_3d.py +191 -0
  15. ff_spatio_temp_transformer_3d.py +331 -0
  16. imagebind/__init__.py +3 -0
  17. imagebind/__pycache__/__init__.cpython-310.pyc +0 -0
  18. imagebind/__pycache__/data.cpython-310.pyc +0 -0
  19. imagebind/bpe/bpe_simple_vocab_16e6.txt.gz +3 -0
  20. imagebind/data.py +343 -0
  21. imagebind/models/__init__.py +0 -0
  22. imagebind/models/__pycache__/__init__.cpython-310.pyc +0 -0
  23. imagebind/models/__pycache__/helpers.cpython-310.pyc +0 -0
  24. imagebind/models/__pycache__/imagebind_model.cpython-310.pyc +0 -0
  25. imagebind/models/__pycache__/multimodal_preprocessors.cpython-310.pyc +0 -0
  26. imagebind/models/__pycache__/transformer.cpython-310.pyc +0 -0
  27. imagebind/models/helpers.py +140 -0
  28. imagebind/models/imagebind_model.py +506 -0
  29. imagebind/models/multimodal_preprocessors.py +685 -0
  30. imagebind/models/transformer.py +280 -0
  31. pipeline.py +602 -0
  32. pretrained/openai-clip-l_null_text_encoding.pt +3 -0
  33. pretrained/stable-diffusion-v1-5/scheduler/scheduler_config.json +13 -0
  34. pretrained/stable-diffusion-v1-5/vae/config.json +29 -0
  35. pretrained/stable-diffusion-v1-5/vae/diffusion_pytorch_model.bin +3 -0
  36. pretrained/stable-diffusion-v1-5/vae/diffusion_pytorch_model.fp16.bin +3 -0
  37. pretrained/stable-diffusion-v1-5/vae/diffusion_pytorch_model.fp16.safetensors +3 -0
  38. pretrained/stable-diffusion-v1-5/vae/diffusion_pytorch_model.safetensors +3 -0
  39. requirements.txt +11 -0
  40. unet.py +839 -0
  41. unet_blocks.py +1084 -0
  42. unet_utils.py +163 -0
.checkpoints/imagebind_huge.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6f6c22bedcc90708448d5d2fbb7b2db9c73f505dc89bd0b2e09b23af1b62157
3
+ size 4803584173
app.py ADDED
@@ -0,0 +1,140 @@
1
+ import warnings
2
+ warnings.filterwarnings("ignore")
3
+
4
+ import gradio as gr
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from diffusers.models import AutoencoderKL
9
+ from diffusers.schedulers import PNDMScheduler
10
+ from unet import AudioUNet3DConditionModel
11
+ from audio_encoder import ImageBindSegmaskAudioEncoder
12
+ from pipeline import AudioCondAnimationPipeline, generate_videos
13
+
14
+
15
+ device = torch.device("cuda")
16
+ dtype = torch.float16
17
+
18
+
19
+ def freeze_and_make_eval(model: nn.Module):
20
+ for param in model.parameters():
21
+ param.requires_grad = False
22
+ model.eval()
23
+
24
+
25
+ def create_pipeline(device=torch.device("cuda"), dtype=torch.float32):
26
+ # 2. Prepare model
27
+ pretrained_stable_diffusion_path = "./pretrained/stable-diffusion-v1-5"
28
+
29
+ checkpoint_path = f"checkpoints/audio-cond_animation/avsync15_audio-cond_cfg/ckpts/checkpoint-37000/modules"
30
+ category_text_encoding_mapping = torch.load('datasets/AVSync15/class_clip_text_encodings_stable-diffusion-v1-5.pt', map_location="cpu")
31
+
32
+ scheduler = PNDMScheduler.from_pretrained(pretrained_stable_diffusion_path, subfolder="scheduler")
33
+ vae = AutoencoderKL.from_pretrained(pretrained_stable_diffusion_path, subfolder="vae").to(device=device, dtype=dtype)
34
+ audio_encoder = ImageBindSegmaskAudioEncoder(n_segment=12).to(device=device, dtype=dtype)
35
+ freeze_and_make_eval(audio_encoder)
36
+ unet = AudioUNet3DConditionModel.from_pretrained(checkpoint_path, subfolder="unet").to(device=device, dtype=dtype)
37
+
38
+ pipeline = AudioCondAnimationPipeline(
39
+ unet=unet,
40
+ scheduler=scheduler,
41
+ vae=vae,
42
+ audio_encoder=audio_encoder,
43
+ null_text_encodings_path="./pretrained/openai-clip-l_null_text_encoding.pt"
44
+ )
45
+ pipeline.to(torch_device=device, dtype=dtype)
46
+ pipeline.set_progress_bar_config(disable=True)
47
+
48
+ return pipeline, category_text_encoding_mapping
49
+
50
+ pipeline, category_text_encoding_mapping = create_pipeline(device, dtype)
51
+
52
+
53
+ def generate_video(image, audio, text, audio_guidance_scale, denoising_step):
54
+
55
+ category_text_encoding = category_text_encoding_mapping[text].view(1, 77, 768)
56
+
57
+ generate_videos(
58
+ pipeline,
59
+ audio_path=audio,
60
+ image_path=image,
61
+ category_text_encoding=category_text_encoding,
62
+ image_size=(256, 256),
63
+ video_fps=6,
64
+ video_num_frame=12,
65
+ text_guidance_scale=1.0,
66
+ audio_guidance_scale=audio_guidance_scale,
67
+ denoising_step=denoising_step,
68
+ seed=123,
69
+ save_path="./output_video.mp4",
70
+ device=device
71
+ )
72
+
73
+ return "./output_video.mp4"
74
+
75
+
76
+ if __name__ == "__main__":
77
+
78
+ categories = [
79
+ "baby babbling crying", "dog barking", "hammering", "striking bowling", "cap gun shooting",
80
+ "chicken crowing", "frog croaking", "lions roaring", "machine gun shooting", "playing cello",
81
+ "playing trombone", "playing trumpet", "playing violin fiddle", "sharpen knife", "toilet flushing"
82
+ ]
83
+
84
+ title = ""
85
+ description = """
86
+ <div align="center">
87
+
88
+ <h1 style="font-size: 60px;">Audio-Synchronized Visual Animation</h1>
89
+
90
+ <p style="font-size: 30px;">
91
+ <a href="https://lzhangbj.github.io/projects/asva/asva.html">Project Webpage</a>
92
+ </p>
93
+
94
+ <p style="font-size: 30px;">
95
+ <a href="https://lzhangbj.github.io/">Lin Zhang</a>,
96
+ <a href="https://scholar.google.com/citations?user=6aYncPAAAAAJ">Shentong Mo</a>,
97
+ <a href="https://yijingz02.github.io/">Yijing Zhang</a>,
98
+ <a href="https://pedro-morgado.github.io/">Pedro Morgado</a>
99
+ </p>
100
+
101
+ <p style="font-size: 30px;">
102
+ University of Wisconsin Madison,
103
+ Carnegie Mellon University
104
+ </p>
105
+
106
+ <strong style="font-size: 30px;">ECCV 2024</strong>
107
+
108
+ <strong style="font-size: 25px;">Animate your images with audio-synchronized motion! </strong>
109
+
110
+ <p style="font-size: 18px;">Notes:</p>
111
+ <p style="font-size: 18px;">(1) Only the first 2 seconds of audio is used. </p>
112
+ <p style="font-size: 18px;">(2) Increase audio guidance scale for amplified visual dynamics. </p>
113
+ <p style="font-size: 18px;">(3) Increase sampling steps for higher visual quality. </p>
114
+
115
+ </div>
116
+ """
117
+
118
+ # <p style="font-size: 20px;">Please be patient. Due to limited resources on huggingface, the generation may take up to 10mins </p>
119
+
120
+ # Gradio Interface
121
+ iface = gr.Interface(
122
+ fn=generate_video,
123
+ inputs=[
124
+ gr.Image( label="Upload Image", type="filepath", height=256),
125
+ gr.Audio(label="Upload Audio", type="filepath"),
126
+ gr.Dropdown(choices=categories, label="Select Audio Category"),
127
+ gr.Slider(minimum=1.0, maximum=12.0, step=0.1, value=4.0, label="Audio Guidance Scale"),
128
+ gr.Slider(minimum=1, maximum=50, step=1, value=20, label="Sampling steps")
129
+ ],
130
+ outputs=gr.Video(label="Generated Video", height=256),
131
+ title=title,
132
+ description=description,
133
+ examples = [
134
+ ["./assets/lion_and_gun.png", "./assets/lions_roaring.wav", "lions roaring", 4.0, 20],
135
+ ["./assets/lion_and_gun.png", "./assets/machine_gun_shooting.wav", "machine gun shooting", 4.0, 20],
136
+ ]
137
+ )
138
+
139
+ # Launch the interface
140
+ iface.launch()
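
Note: a minimal sketch of driving the generate_video helper above directly from Python, without launching the Gradio UI. It assumes the environment app.py itself assumes (this repo's assets and checkpoints on disk, plus a CUDA device, since app.py pins device = torch.device("cuda")); importing app builds the pipeline once at module load.

# Sketch only: relies on the files added in this commit and a CUDA device.
from app import generate_video

# The text label must be one of the AVSync15 categories listed in app.py,
# because it indexes category_text_encoding_mapping.
video_path = generate_video(
    image="./assets/lion_and_gun.png",
    audio="./assets/lions_roaring.wav",
    text="lions roaring",
    audio_guidance_scale=4.0,
    denoising_step=20,
)
print(video_path)  # "./output_video.mp4", the path hard-coded in app.py
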
assets/.DS_Store ADDED
Binary file (6.15 kB).
 
assets/lion_and_gun.png ADDED
assets/lions_roaring.wav ADDED
Binary file (135 kB).
 
assets/machine_gun_shooting.wav ADDED
Binary file (885 kB).
 
audio_encoder.py ADDED
@@ -0,0 +1,124 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Optional, TypeVar, Tuple, Any
4
+
5
+ T = TypeVar('T', bound='Module')
6
+ from einops import rearrange, repeat
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from transformers.utils import ModelOutput
13
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
14
+
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
17
+
18
+ from imagebind.models import imagebind_model
19
+ from imagebind.models.imagebind_model import ModalityType
20
+
21
+
22
+ @dataclass
23
+ class ImageBindSegmaskAudioEncoderOutput(ModelOutput):
24
+ """
25
+ Args:
26
+ audio_embeds (`torch.Tensor` of shape `(batch_size, output_dim)`):
27
+ The pooled audio embeddings obtained by applying the ImageBind audio head
28
+ (and, when `normalize=True`, the logit-scaling postprocessor) to the trunk output.
29
+ audio_encodings (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_dim)`):
30
+ The per-token audio encodings from the ImageBind audio trunk, after the final layer norm.
31
+ audio_segment_masks (`torch.BoolTensor` of shape `(batch_size, n_segment, sequence_length)`):
32
+ Boolean masks indicating which audio tokens (including the leading token)
33
+ belong to each temporal segment.
34
+ """
35
+ audio_embeds: torch.Tensor = None
36
+ audio_encodings: torch.Tensor = None
37
+ audio_segment_masks: torch.BoolTensor = None
38
+
39
+ def to_tuple(self) -> Tuple[Any]:
40
+ return tuple(self[k] for k in self.keys())
41
+
42
+
43
+ class ImageBindSegmaskAudioEncoder(ModelMixin, ConfigMixin):
44
+
45
+ @register_to_config
46
+ def __init__(self,
47
+ n_segment=4,
48
+ pretrained_model_name="imagebind-huge"
49
+ ):
50
+ super().__init__()
51
+ self.n_segment = n_segment
52
+
53
+ self.pretrained_model_name = pretrained_model_name
54
+ if pretrained_model_name == "imagebind-huge":
55
+ pretrained_model = imagebind_model.imagebind_huge(pretrained=True)
56
+
57
+ self.preprocessor = pretrained_model.modality_preprocessors[ModalityType.AUDIO]
58
+ self.trunk = pretrained_model.modality_trunks[ModalityType.AUDIO]
59
+ self.head = pretrained_model.modality_heads[ModalityType.AUDIO]
60
+ self.postprocessor = pretrained_model.modality_postprocessors[ModalityType.AUDIO]
61
+ self.final_layer_norm = nn.LayerNorm(normalized_shape=768, eps=1e-6)
62
+
63
+ def _auto_split(self, n, n_chunk):
64
+ '''
65
+ automatically split n elements into n_chunk chunks of equal size ceil(n / n_chunk)
66
+ if n is not divisible by n_chunk, chunk start indices are spaced evenly, so neighbouring chunks may overlap
67
+ '''
68
+ chunk_size = int(math.ceil(n / n_chunk))
69
+ assert chunk_size >= 1, chunk_size
70
+
71
+ chunk_start_indices = np.round(np.linspace(0, n - chunk_size, n_chunk, endpoint=True)).astype(np.int32)
72
+
73
+ mask = torch.zeros(n_chunk, n).bool()
74
+ for chunk_index, chunk_start_index in enumerate(chunk_start_indices):
75
+ mask[chunk_index, chunk_start_index:chunk_start_index + chunk_size] = 1
76
+ mask = mask.contiguous()
77
+ assert mask.long().sum() == chunk_size * n_chunk, mask.long().sum()
78
+
79
+ return mask
80
+
81
+ def forward(self,
82
+ input_features: Optional[torch.Tensor],
83
+ normalize: bool = False,
84
+ return_dict: Optional[bool] = None):
85
+
86
+ n_segment = self.n_segment
87
+
88
+ # 1. reshape to imagebind input
89
+ batchsize = input_features.size(0)
90
+
91
+ # 2. patchify images and add positional embedding and
92
+ audio_inputs = self.preprocessor(input_features)
93
+ trunk_inputs = audio_inputs["trunk"] # dict of {"tokens": (b, l, d)}
94
+
95
+ # 3. get audio encoder output
96
+ audio_encodings = self.trunk(**trunk_inputs) # w/o layer norm (b, seq_len, c)
97
+ head_inputs = audio_inputs["head"]
98
+ cls_embeds = self.head(audio_encodings, **head_inputs)
99
+ # normalize and logit scaling
100
+ if normalize:
101
+ cls_embeds = self.postprocessor(cls_embeds) # (b, c)
102
+ audio_encodings = self.final_layer_norm(audio_encodings)
103
+
104
+ # 4. get segment masks
105
+ n, t = 12, 19 # hard-coded: ImageBind's audio preprocessor yields a 12 x 19 grid of patch tokens (frequency x time)
106
+ segment_mask = self._auto_split(t, n_segment).unsqueeze(1).expand(n_segment, n, t).contiguous() # (s, n, t)
107
+ segment_mask = rearrange(
108
+ segment_mask, "s n t -> s (n t)"
109
+ )
110
+ segment_mask = torch.cat([
111
+ torch.ones(n_segment, 1).bool(),
112
+ segment_mask
113
+ ], dim=1) # (s, 1+n*t)
114
+
115
+ segment_masks = repeat(segment_mask, "n s -> b n s", b=batchsize).contiguous().bool().to(self.device)
116
+
117
+ if not return_dict:
118
+ return cls_embeds, audio_encodings, segment_masks
119
+
120
+ return ImageBindSegmaskAudioEncoderOutput(
121
+ audio_embeds=cls_embeds,
122
+ audio_encodings=audio_encodings,
123
+ audio_segment_masks=segment_masks
124
+ )
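
A small, self-contained sketch of the segment-mask construction above. The splitting logic mirrors _auto_split, and n_segment=12 over t=19 matches the hard-coded ImageBind audio patch grid; in the forward pass an extra leading column is then prepended for the first (class) token.

import math
import numpy as np
import torch

def auto_split(n, n_chunk):
    # Same logic as ImageBindSegmaskAudioEncoder._auto_split above:
    # n_chunk equal-size chunks of ceil(n / n_chunk) elements with evenly spaced starts.
    chunk_size = int(math.ceil(n / n_chunk))
    starts = np.round(np.linspace(0, n - chunk_size, n_chunk, endpoint=True)).astype(np.int32)
    mask = torch.zeros(n_chunk, n).bool()
    for i, s in enumerate(starts):
        mask[i, s:s + chunk_size] = True
    return mask

t, n_segment = 19, 12            # 19 time patches split into 12 segments
mask = auto_split(t, n_segment)  # (12, 19); each row covers 2 consecutive time patches
print(mask.long())
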
checkpoints/audio-cond_animation/avsync15_audio-cond_cfg/ckpts/checkpoint-37000/modules/audio_encoder/config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "_class_name": "ImageBindSegmaskAudioEncoder",
3
+ "_diffusers_version": "0.29.2",
4
+ "n_segment": 12,
5
+ "pretrained_model_name": "imagebind-huge"
6
+ }
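
Because ImageBindSegmaskAudioEncoder subclasses diffusers' ModelMixin/ConfigMixin, this config plus the safetensors file below should, in principle, also allow restoring the fine-tuned encoder in one call. This is a hedged sketch, not what app.py does (app.py instantiates the encoder directly on top of the pretrained ImageBind weights).

# Sketch only: assumes the checkpoint layout added in this commit.
from audio_encoder import ImageBindSegmaskAudioEncoder

checkpoint_root = "checkpoints/audio-cond_animation/avsync15_audio-cond_cfg/ckpts/checkpoint-37000/modules"
audio_encoder = ImageBindSegmaskAudioEncoder.from_pretrained(
    checkpoint_root, subfolder="audio_encoder"  # reads this config.json + diffusion_pytorch_model.safetensors
)
# Note: the constructor may still fetch the base ImageBind-Huge checkpoint before the
# fine-tuned state dict is applied.
audio_encoder.eval()
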
checkpoints/audio-cond_animation/avsync15_audio-cond_cfg/ckpts/checkpoint-37000/modules/audio_encoder/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93622a01c9bdd6bad87530617f0fdc772be958dc435b3303ed66ba938311aa4b
3
+ size 172492226
checkpoints/audio-cond_animation/avsync15_audio-cond_cfg/ckpts/checkpoint-37000/modules/unet/config.json ADDED
@@ -0,0 +1,61 @@
1
+ {
2
+ "_class_name": "AudioUNet3DConditionModel",
3
+ "_diffusers_version": "0.29.2",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": null,
6
+ "addition_embed_type_num_heads": 64,
7
+ "attention_head_dim": 8,
8
+ "audio_cross_attention_dim": 768,
9
+ "block_out_channels": [
10
+ 320,
11
+ 640,
12
+ 1280,
13
+ 1280
14
+ ],
15
+ "center_input_sample": false,
16
+ "class_embed_type": null,
17
+ "class_embeddings_concat": false,
18
+ "conv_in_kernel": 3,
19
+ "conv_out_kernel": 3,
20
+ "cross_attention_dim": 768,
21
+ "cross_attention_norm": null,
22
+ "down_block_types": [
23
+ "FFSpatioAudioTempCrossAttnDownBlock3D",
24
+ "FFSpatioAudioTempCrossAttnDownBlock3D",
25
+ "FFSpatioAudioTempCrossAttnDownBlock3D",
26
+ "FFSpatioTempResDownBlock3D"
27
+ ],
28
+ "downsample_padding": 1,
29
+ "dual_cross_attention": false,
30
+ "encoder_hid_dim": null,
31
+ "flip_sin_to_cos": true,
32
+ "freq_shift": 0,
33
+ "in_channels": 4,
34
+ "layers_per_block": 2,
35
+ "mid_block_only_cross_attention": null,
36
+ "mid_block_scale_factor": 1,
37
+ "mid_block_type": "FFSpatioAudioTempCrossAttnUNetMidBlock3D",
38
+ "norm_eps": 1e-05,
39
+ "norm_num_groups": 32,
40
+ "num_class_embeds": null,
41
+ "only_cross_attention": false,
42
+ "out_channels": 4,
43
+ "projection_class_embeddings_input_dim": null,
44
+ "resnet_out_scale_factor": 1.0,
45
+ "resnet_skip_time_act": false,
46
+ "resnet_time_scale_shift": "default",
47
+ "sample_size": 64,
48
+ "time_cond_proj_dim": null,
49
+ "time_embedding_act_fn": null,
50
+ "time_embedding_dim": null,
51
+ "time_embedding_type": "positional",
52
+ "timestep_post_act": null,
53
+ "up_block_types": [
54
+ "FFSpatioTempResUpBlock3D",
55
+ "FFSpatioAudioTempCrossAttnUpBlock3D",
56
+ "FFSpatioAudioTempCrossAttnUpBlock3D",
57
+ "FFSpatioAudioTempCrossAttnUpBlock3D"
58
+ ],
59
+ "upcast_attention": false,
60
+ "use_linear_projection": false
61
+ }
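
For orientation: this config describes a 4-level 3D UNet (320/640/1280/1280 channels) in which the first three down blocks and last three up blocks use audio/text cross-attention, while the deepest down/up blocks are residual-only. A short sketch repeating how app.py above restores it:

# Sketch based on app.py; paths are the ones added in this commit.
import torch
from unet import AudioUNet3DConditionModel

checkpoint_path = "checkpoints/audio-cond_animation/avsync15_audio-cond_cfg/ckpts/checkpoint-37000/modules"
unet = AudioUNet3DConditionModel.from_pretrained(checkpoint_path, subfolder="unet").to(
    device=torch.device("cuda"), dtype=torch.float16
)
print(unet.config.down_block_types)  # three cross-attn blocks, then a residual-only block
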
checkpoints/audio-cond_animation/avsync15_audio-cond_cfg/ckpts/checkpoint-37000/modules/unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:234652f6029bd49d05d6e77e5fe6721e239bbb4ae93a60112ea53d95824da097
3
+ size 4677570888
datasets/AVSync15/class_clip_text_encodings_stable-diffusion-v1-5.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10b3e0bcf2f12ee7c0410165e2872ae76fe3a58f9d43834781cc8bd79c5cfc46
3
+ size 3553440
ff_spatio_audio_temp_transformer_3d.py ADDED
@@ -0,0 +1,374 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+ from einops import rearrange
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+ from diffusers.utils import BaseOutput
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from diffusers.models.attention import Attention
15
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero
16
+ from diffusers.models.embeddings import Timesteps, TimestepEmbedding
17
+
18
+ from unet_utils import FFAttention
19
+
20
+
21
+ @dataclass
22
+ class SpatioTempTransformer3DModelOutput(BaseOutput):
23
+ sample: torch.Tensor
24
+
25
+
26
+ if is_xformers_available():
27
+ import xformers
28
+ import xformers.ops
29
+ else:
30
+ xformers = None
31
+
32
+
33
+ class FFSpatioAudioTempTransformer3DModel(ModelMixin, ConfigMixin):
34
+
35
+ @register_to_config
36
+ def __init__(
37
+ self,
38
+ num_attention_heads: int = 16,
39
+ attention_head_dim: int = 88,
40
+ in_channels: Optional[int] = None,
41
+ num_layers: int = 1,
42
+ dropout: float = 0.0,
43
+ norm_num_groups: int = 32,
44
+ cross_attention_dim: Optional[int] = None,
45
+ audio_cross_attention_dim: Optional[int] = None,
46
+ attention_bias: bool = False,
47
+ activation_fn: str = "geglu",
48
+ num_embeds_ada_norm: Optional[int] = None,
49
+ use_linear_projection: bool = False,
50
+ only_cross_attention: bool = False,
51
+ upcast_attention: bool = False,
52
+ ):
53
+ super().__init__()
54
+ self.use_linear_projection = use_linear_projection
55
+ self.num_attention_heads = num_attention_heads
56
+ self.attention_head_dim = attention_head_dim
57
+ inner_dim = num_attention_heads * attention_head_dim
58
+
59
+ # Define input layers
60
+ self.in_channels = in_channels
61
+
62
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
63
+ if use_linear_projection:
64
+ self.proj_in = nn.Linear(in_channels, inner_dim)
65
+ else:
66
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
67
+
68
+ # Define transformers blocks
69
+ self.transformer_blocks = nn.ModuleList(
70
+ [
71
+ BasicTransformerBlock(
72
+ inner_dim,
73
+ num_attention_heads,
74
+ attention_head_dim,
75
+ dropout=dropout,
76
+ cross_attention_dim=cross_attention_dim,
77
+ audio_cross_attention_dim=audio_cross_attention_dim,
78
+ activation_fn=activation_fn,
79
+ num_embeds_ada_norm=num_embeds_ada_norm,
80
+ attention_bias=attention_bias,
81
+ only_cross_attention=only_cross_attention,
82
+ upcast_attention=upcast_attention,
83
+ )
84
+ for d in range(num_layers)
85
+ ]
86
+ )
87
+
88
+ # 4. Define output layers
89
+ if use_linear_projection:
90
+ self.proj_out = nn.Linear(in_channels, inner_dim)
91
+ else:
92
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
93
+
94
+ def forward(
95
+ self,
96
+ hidden_states,
97
+ encoder_hidden_states=None,
98
+ audio_encoder_hidden_states=None,
99
+ audio_attention_mask=None,
100
+ timestep=None,
101
+ class_labels=None,
102
+ cross_attention_kwargs=None,
103
+ return_dict: bool = True
104
+ ):
105
+ # Input
106
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
107
+ video_length = hidden_states.shape[2]
108
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
109
+ encoder_hidden_states = rearrange(encoder_hidden_states, 'b f n c -> (b f) n c')
110
+ audio_encoder_hidden_states = rearrange(audio_encoder_hidden_states, 'b f n c -> (b f) n c')
111
+ if audio_attention_mask is not None:
112
+ audio_attention_mask = rearrange(audio_attention_mask, 'b f n -> (b f) 1 n')
113
+
114
+ batch, channel, height, weight = hidden_states.shape
115
+ residual = hidden_states
116
+
117
+ hidden_states = self.norm(hidden_states)
118
+ if not self.use_linear_projection:
119
+ hidden_states = self.proj_in(hidden_states)
120
+ inner_dim = hidden_states.shape[1]
121
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
122
+ else:
123
+ inner_dim = hidden_states.shape[1]
124
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
125
+ hidden_states = self.proj_in(hidden_states)
126
+
127
+ # Blocks
128
+ for block in self.transformer_blocks:
129
+ hidden_states = block(
130
+ hidden_states,
131
+ encoder_hidden_states=encoder_hidden_states,
132
+ audio_encoder_hidden_states=audio_encoder_hidden_states,
133
+ audio_attention_mask=audio_attention_mask,
134
+ timestep=timestep,
135
+ video_length=video_length,
136
+ cross_attention_kwargs=cross_attention_kwargs,
137
+ class_labels=class_labels
138
+ )
139
+
140
+ # Output
141
+ if not self.use_linear_projection:
142
+ hidden_states = (
143
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
144
+ )
145
+ hidden_states = self.proj_out(hidden_states)
146
+ else:
147
+ hidden_states = self.proj_out(hidden_states)
148
+ hidden_states = (
149
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
150
+ )
151
+
152
+ output = hidden_states + residual
153
+
154
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
155
+ if not return_dict:
156
+ return (output,)
157
+
158
+ return SpatioTempTransformer3DModelOutput(sample=output)
159
+
160
+
161
+ class BasicTransformerBlock(nn.Module):
162
+ def __init__(
163
+ self,
164
+ dim: int,
165
+ num_attention_heads: int,
166
+ attention_head_dim: int,
167
+ dropout=0.0,
168
+ cross_attention_dim: Optional[int] = None,
169
+ audio_cross_attention_dim: Optional[int] = None,
170
+ activation_fn: str = "geglu",
171
+ num_embeds_ada_norm: Optional[int] = None,
172
+ attention_bias: bool = False,
173
+ only_cross_attention: bool = False,
174
+ double_self_attention: bool = False,
175
+ upcast_attention: bool = False,
176
+ norm_elementwise_affine: bool = True,
177
+ norm_type: str = "layer_norm",
178
+ final_dropout: bool = False,
179
+ ):
180
+ super().__init__()
181
+ self.only_cross_attention = only_cross_attention
182
+
183
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
184
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
185
+
186
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
187
+ raise ValueError(
188
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
189
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
190
+ )
191
+
192
+ # Define 5 blocks. Each block has its own normalization layer.
193
+ # 1. SC-Cross-Attn
194
+ if self.use_ada_layer_norm:
195
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
196
+ elif self.use_ada_layer_norm_zero:
197
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
198
+ else:
199
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
200
+ self.attn1 = FFAttention(
201
+ query_dim=dim,
202
+ heads=num_attention_heads,
203
+ dim_head=attention_head_dim,
204
+ dropout=dropout,
205
+ bias=attention_bias,
206
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
207
+ upcast_attention=upcast_attention,
208
+ )
209
+
210
+ # 2. Audio Conditioned Cross-Attn
211
+ self.norm_audio = (
212
+ AdaLayerNorm(dim, num_embeds_ada_norm)
213
+ if self.use_ada_layer_norm
214
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
215
+ )
216
+ self.attn_audio = Attention(
217
+ query_dim=dim,
218
+ cross_attention_dim=audio_cross_attention_dim,
219
+ heads=num_attention_heads,
220
+ dim_head=attention_head_dim,
221
+ dropout=dropout,
222
+ bias=attention_bias,
223
+ upcast_attention=upcast_attention,
224
+ )
225
+
226
+ # 3. Cross-Attn
227
+ if cross_attention_dim is not None or double_self_attention:
228
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
229
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
230
+ # the second cross attention block.
231
+ self.norm2 = (
232
+ AdaLayerNorm(dim, num_embeds_ada_norm)
233
+ if self.use_ada_layer_norm
234
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
235
+ )
236
+ self.attn2 = Attention(
237
+ query_dim=dim,
238
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
239
+ heads=num_attention_heads,
240
+ dim_head=attention_head_dim,
241
+ dropout=dropout,
242
+ bias=attention_bias,
243
+ upcast_attention=upcast_attention,
244
+ ) # is self-attn if encoder_hidden_states is none
245
+ else:
246
+ self.norm2 = None
247
+ self.attn2 = None
248
+
249
+ # 4. Temp-Attn
250
+ self.pos_proj_temp = Timesteps(dim, flip_sin_to_cos=True, downscale_freq_shift=0)
251
+ self.pos_embedding_temp = TimestepEmbedding(
252
+ dim,
253
+ dim,
254
+ act_fn="silu",
255
+ post_act_fn=None,
256
+ cond_proj_dim=None,
257
+ )
258
+
259
+ self.attn_temp = Attention(
260
+ query_dim=dim,
261
+ heads=num_attention_heads,
262
+ dim_head=attention_head_dim,
263
+ dropout=dropout,
264
+ bias=attention_bias,
265
+ upcast_attention=upcast_attention,
266
+ )
267
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
268
+ self.norm_temp = (
269
+ AdaLayerNorm(dim, num_embeds_ada_norm)
270
+ if self.use_ada_layer_norm
271
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
272
+ )
273
+
274
+ # 5. Feed-forward
275
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
276
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
277
+
278
+ def forward(
279
+ self,
280
+ hidden_states,
281
+ attention_mask=None,
282
+ encoder_hidden_states=None,
283
+ encoder_attention_mask=None,
284
+ audio_encoder_hidden_states=None,
285
+ audio_attention_mask=None,
286
+ timestep=None,
287
+ video_length=None,
288
+ cross_attention_kwargs=None,
289
+ class_labels=None,
290
+ ):
291
+ # Notice that normalization is always applied before the real computation in the following blocks.
292
+ # 1. Self-Attention
293
+ if self.use_ada_layer_norm:
294
+ norm_hidden_states = self.norm1(hidden_states, timestep)
295
+ elif self.use_ada_layer_norm_zero:
296
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
297
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
298
+ )
299
+ else:
300
+ norm_hidden_states = self.norm1(hidden_states)
301
+
302
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
303
+ attn_output = self.attn1(
304
+ norm_hidden_states,
305
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
306
+ attention_mask=attention_mask,
307
+ video_length=video_length,
308
+ **cross_attention_kwargs,
309
+ )
310
+ if self.use_ada_layer_norm_zero:
311
+ attn_output = gate_msa.unsqueeze(1) * attn_output
312
+ hidden_states = attn_output + hidden_states
313
+
314
+ # 2. Audio Cross-Attention
315
+ if self.attn_audio is not None:
316
+ norm_hidden_states = (
317
+ self.norm_audio(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_audio(hidden_states)
318
+ )
319
+ attn_output = self.attn_audio(
320
+ norm_hidden_states,
321
+ encoder_hidden_states=audio_encoder_hidden_states,
322
+ attention_mask=audio_attention_mask,
323
+ **cross_attention_kwargs,
324
+ )
325
+ hidden_states = attn_output + hidden_states
326
+
327
+ # 3. Cross-Attention
328
+ if self.attn2 is not None:
329
+ norm_hidden_states = (
330
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
331
+ )
332
+ # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
333
+ # prepare attention mask here
334
+
335
+ attn_output = self.attn2(
336
+ norm_hidden_states,
337
+ encoder_hidden_states=encoder_hidden_states,
338
+ attention_mask=encoder_attention_mask,
339
+ **cross_attention_kwargs,
340
+ )
341
+ hidden_states = attn_output + hidden_states
342
+
343
+ # 3. Temporal-Attention
344
+
345
+ # Add positional embedding
346
+ device = hidden_states.device
347
+ dtype = hidden_states.dtype
348
+ pos_embed = self.pos_proj_temp(torch.arange(video_length).long()).to(device=device, dtype=dtype) # (f c)
349
+ pos_embed = self.pos_embedding_temp(pos_embed).unsqueeze(0) # (1, f, c)
350
+
351
+ seq_len = hidden_states.shape[1]
352
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
353
+ norm_hidden_states = (
354
+ self.norm_temp(hidden_states + pos_embed, timestep) if self.use_ada_layer_norm else self.norm_temp(
355
+ hidden_states + pos_embed)
356
+ )
357
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
358
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=seq_len)
359
+
360
+ # 4. Feed-forward
361
+ norm_hidden_states = self.norm3(hidden_states)
362
+
363
+ if self.use_ada_layer_norm_zero:
364
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
365
+
366
+ ff_output = self.ff(norm_hidden_states)
367
+
368
+ if self.use_ada_layer_norm_zero:
369
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
370
+
371
+ hidden_states = ff_output + hidden_states
372
+
373
+ return hidden_states
374
+
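
The forward pass above handles the frame axis by folding it into the batch: video latents (b, c, f, h, w) become (b·f, c, h, w) for the spatial and cross-attention blocks, per-frame conditioning (b, f, n, c) becomes (b·f, n, c), and the temporal attention inside each block folds the spatial positions instead. A tiny runnable sketch of just those einops patterns (shapes are illustrative, not taken from a real checkpoint):

import torch
from einops import rearrange

b, c, f, h, w = 2, 320, 12, 8, 8
hidden_states = torch.randn(b, c, f, h, w)
text_cond = torch.randn(b, f, 77, 768)            # per-frame text encodings
audio_cond = torch.randn(b, f, 1 + 12 * 19, 768)  # per-frame audio tokens (leading token + 12x19 patches)

# Fold frames into the batch for the spatial / cross-attention blocks
x = rearrange(hidden_states, "b c f h w -> (b f) c h w")
text_cond = rearrange(text_cond, "b f n c -> (b f) n c")
audio_cond = rearrange(audio_cond, "b f n c -> (b f) n c")

# Inside a block, temporal attention folds spatial positions instead
tokens = rearrange(x, "(b f) c h w -> (b f) (h w) c", f=f)
temporal = rearrange(tokens, "(b f) d c -> (b d) f c", f=f)   # attend across the 12 frames
tokens = rearrange(temporal, "(b d) f c -> (b f) d c", d=h * w)

# Unfold back to video layout at the end of the transformer
out = rearrange(x, "(b f) c h w -> b c f h w", f=f)
print(out.shape)  # torch.Size([2, 320, 12, 8, 8])
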
ff_spatio_temp_resnet_3d.py ADDED
@@ -0,0 +1,191 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+ from einops import rearrange
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from unet_utils import FFInflatedConv3d
8
+
9
+
10
+ class FFSpatioTempResUpsample3D(nn.Module):
11
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
12
+ super().__init__()
13
+ self.channels = channels
14
+ self.out_channels = out_channels or channels
15
+ self.use_conv = use_conv
16
+ self.use_conv_transpose = use_conv_transpose
17
+ self.name = name
18
+
19
+ conv = None
20
+ if use_conv_transpose:
21
+ raise NotImplementedError
22
+ elif use_conv:
23
+ conv = FFInflatedConv3d(self.channels, self.out_channels, 3, padding=1)
24
+
25
+ if name == "conv":
26
+ self.conv = conv
27
+ else:
28
+ self.Conv2d_0 = conv
29
+
30
+ def forward(self, hidden_states, output_size=None):
31
+ assert hidden_states.shape[1] == self.channels
32
+
33
+ if self.use_conv_transpose:
34
+ raise NotImplementedError
35
+
36
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
37
+ dtype = hidden_states.dtype
38
+ if dtype == torch.bfloat16:
39
+ hidden_states = hidden_states.to(torch.float32)
40
+
41
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
42
+ if hidden_states.shape[0] >= 64:
43
+ hidden_states = hidden_states.contiguous()
44
+
45
+ # if `output_size` is passed we force the interpolation output
46
+ # size and do not make use of `scale_factor=2`
47
+ if output_size is None:
48
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
49
+ else:
50
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
51
+
52
+ # If the input is bfloat16, we cast back to bfloat16
53
+ if dtype == torch.bfloat16:
54
+ hidden_states = hidden_states.to(dtype)
55
+
56
+ if self.use_conv:
57
+ if self.name == "conv":
58
+ hidden_states = self.conv(hidden_states)
59
+ else:
60
+ hidden_states = self.Conv2d_0(hidden_states)
61
+
62
+ return hidden_states
63
+
64
+
65
+ class FFSpatioTempResDownsample3D(nn.Module):
66
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
67
+ super().__init__()
68
+ self.channels = channels
69
+ self.out_channels = out_channels or channels
70
+ self.use_conv = use_conv
71
+ self.padding = padding
72
+ stride = 2
73
+ self.name = name
74
+
75
+ if use_conv:
76
+ conv = FFInflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
77
+ else:
78
+ raise NotImplementedError
79
+
80
+ if name == "conv":
81
+ self.Conv2d_0 = conv
82
+ self.conv = conv
83
+ elif name == "Conv2d_0":
84
+ self.conv = conv
85
+ else:
86
+ self.conv = conv
87
+
88
+ def forward(self, hidden_states):
89
+ assert hidden_states.shape[1] == self.channels
90
+ if self.use_conv and self.padding == 0:
91
+ raise NotImplementedError
92
+
93
+ assert hidden_states.shape[1] == self.channels
94
+ hidden_states = self.conv(hidden_states)
95
+
96
+ return hidden_states
97
+
98
+
99
+ class FFSpatioTempResnetBlock3D(nn.Module):
100
+ def __init__(
101
+ self,
102
+ *,
103
+ in_channels,
104
+ out_channels=None,
105
+ conv_shortcut=False,
106
+ dropout=0.0,
107
+ temb_channels=512,
108
+ groups=32,
109
+ groups_out=None,
110
+ pre_norm=True,
111
+ eps=1e-6,
112
+ non_linearity="swish",
113
+ time_embedding_norm="default",
114
+ output_scale_factor=1.0,
115
+ use_in_shortcut=None
116
+ ):
117
+ super().__init__()
118
+ self.pre_norm = pre_norm
119
+ self.pre_norm = True
120
+ self.in_channels = in_channels
121
+ out_channels = in_channels if out_channels is None else out_channels
122
+ self.out_channels = out_channels
123
+ self.use_conv_shortcut = conv_shortcut
124
+ self.time_embedding_norm = time_embedding_norm
125
+ self.output_scale_factor = output_scale_factor
126
+
127
+ if groups_out is None:
128
+ groups_out = groups
129
+
130
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
131
+
132
+ self.conv1 = FFInflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
133
+
134
+ if temb_channels is not None:
135
+ if self.time_embedding_norm == "default":
136
+ time_emb_proj_out_channels = out_channels
137
+ elif self.time_embedding_norm == "scale_shift":
138
+ time_emb_proj_out_channels = out_channels * 2
139
+ else:
140
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
141
+
142
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
143
+ else:
144
+ self.time_emb_proj = None
145
+
146
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
147
+ self.dropout = torch.nn.Dropout(dropout)
148
+ self.conv2 = FFInflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
149
+
150
+ if non_linearity == "swish":
151
+ self.nonlinearity = lambda x: F.silu(x)
152
+ elif non_linearity == "silu":
153
+ self.nonlinearity = nn.SiLU()
154
+
155
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
156
+
157
+ self.conv_shortcut = None
158
+ if self.use_in_shortcut:
159
+ self.conv_shortcut = FFInflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
160
+
161
+ def forward(self, input_tensor, temb):
162
+ hidden_states = input_tensor
163
+
164
+ hidden_states = self.norm1(hidden_states)
165
+ hidden_states = self.nonlinearity(hidden_states)
166
+
167
+ hidden_states = self.conv1(hidden_states)
168
+
169
+ if temb is not None:
170
+ temb = rearrange(self.time_emb_proj(self.nonlinearity(temb)), "b f c -> b c f")[:, :, :, None, None]
171
+
172
+ if temb is not None and self.time_embedding_norm == "default":
173
+ hidden_states = hidden_states + temb
174
+
175
+ hidden_states = self.norm2(hidden_states)
176
+
177
+ if temb is not None and self.time_embedding_norm == "scale_shift":
178
+ scale, shift = torch.chunk(temb, 2, dim=1)
179
+ hidden_states = hidden_states * (1 + scale) + shift
180
+
181
+ hidden_states = self.nonlinearity(hidden_states)
182
+
183
+ hidden_states = self.dropout(hidden_states)
184
+ hidden_states = self.conv2(hidden_states)
185
+
186
+ if self.conv_shortcut is not None:
187
+ input_tensor = self.conv_shortcut(input_tensor)
188
+
189
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
190
+
191
+ return output_tensor
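
One detail worth flagging in FFSpatioTempResnetBlock3D.forward above: the timestep embedding arrives per frame as (b, f, temb_channels), is projected, and is then broadcast over the spatial dims as (b, c, f, 1, 1), so every frame can carry its own timestep. A minimal sketch of that broadcasting (the plain linear layer below stands in for self.time_emb_proj and omits the nonlinearity):

import torch
import torch.nn as nn
from einops import rearrange

b, f, temb_channels, out_channels = 2, 12, 1280, 320
time_emb_proj = nn.Linear(temb_channels, out_channels)  # stand-in for self.time_emb_proj

temb = torch.randn(b, f, temb_channels)                  # one embedding per frame
temb = rearrange(time_emb_proj(temb), "b f c -> b c f")[:, :, :, None, None]

hidden_states = torch.randn(b, out_channels, f, 16, 16)  # (b, c, f, h, w) feature map
hidden_states = hidden_states + temb                     # broadcast over h and w
print(hidden_states.shape)  # torch.Size([2, 320, 12, 16, 16])
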
ff_spatio_temp_transformer_3d.py ADDED
@@ -0,0 +1,331 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+ from einops import rearrange
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+ from diffusers.utils import BaseOutput
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from diffusers.models.attention import Attention
15
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero
16
+ from diffusers.models.embeddings import Timesteps, TimestepEmbedding
17
+
18
+ from unet_utils import FFAttention
19
+
20
+
21
+ @dataclass
22
+ class SpatioTempTransformer3DModelOutput(BaseOutput):
23
+ sample: torch.Tensor
24
+
25
+
26
+ if is_xformers_available():
27
+ import xformers
28
+ import xformers.ops
29
+ else:
30
+ xformers = None
31
+
32
+
33
+ class FFSpatioTempTransformer3DModel(ModelMixin, ConfigMixin):
34
+ @register_to_config
35
+ def __init__(
36
+ self,
37
+ num_attention_heads: int = 16,
38
+ attention_head_dim: int = 88,
39
+ in_channels: Optional[int] = None,
40
+ num_layers: int = 1,
41
+ dropout: float = 0.0,
42
+ norm_num_groups: int = 32,
43
+ cross_attention_dim: Optional[int] = None,
44
+ attention_bias: bool = False,
45
+ activation_fn: str = "geglu",
46
+ num_embeds_ada_norm: Optional[int] = None,
47
+ use_linear_projection: bool = False,
48
+ only_cross_attention: bool = False,
49
+ upcast_attention: bool = False,
50
+ ):
51
+ super().__init__()
52
+ self.use_linear_projection = use_linear_projection
53
+ self.num_attention_heads = num_attention_heads
54
+ self.attention_head_dim = attention_head_dim
55
+ inner_dim = num_attention_heads * attention_head_dim
56
+
57
+ # Define input layers
58
+ self.in_channels = in_channels
59
+
60
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
61
+ if use_linear_projection:
62
+ self.proj_in = nn.Linear(in_channels, inner_dim)
63
+ else:
64
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
65
+
66
+ # Define transformers blocks
67
+ self.transformer_blocks = nn.ModuleList(
68
+ [
69
+ BasicTransformerBlock(
70
+ inner_dim,
71
+ num_attention_heads,
72
+ attention_head_dim,
73
+ dropout=dropout,
74
+ cross_attention_dim=cross_attention_dim,
75
+ activation_fn=activation_fn,
76
+ num_embeds_ada_norm=num_embeds_ada_norm,
77
+ attention_bias=attention_bias,
78
+ only_cross_attention=only_cross_attention,
79
+ upcast_attention=upcast_attention,
80
+ )
81
+ for d in range(num_layers)
82
+ ]
83
+ )
84
+
85
+ # 4. Define output layers
86
+ if use_linear_projection:
87
+ self.proj_out = nn.Linear(in_channels, inner_dim)
88
+ else:
89
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
90
+
91
+ def forward(
92
+ self,
93
+ hidden_states,
94
+ encoder_hidden_states=None,
95
+ timestep=None,
96
+ class_labels=None,
97
+ cross_attention_kwargs=None,
98
+ return_dict: bool = True):
99
+ # Input
100
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
101
+ video_length = hidden_states.shape[2]
102
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
103
+ encoder_hidden_states = rearrange(encoder_hidden_states, 'b f n c -> (b f) n c')
104
+
105
+ batch, channel, height, weight = hidden_states.shape
106
+ residual = hidden_states
107
+
108
+ hidden_states = self.norm(hidden_states)
109
+ if not self.use_linear_projection:
110
+ hidden_states = self.proj_in(hidden_states)
111
+ inner_dim = hidden_states.shape[1]
112
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
113
+ else:
114
+ inner_dim = hidden_states.shape[1]
115
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
116
+ hidden_states = self.proj_in(hidden_states)
117
+
118
+ # Blocks
119
+ for block in self.transformer_blocks:
120
+ hidden_states = block(
121
+ hidden_states,
122
+ encoder_hidden_states=encoder_hidden_states,
123
+ timestep=timestep,
124
+ video_length=video_length,
125
+ cross_attention_kwargs=cross_attention_kwargs,
126
+ class_labels=class_labels
127
+ )
128
+
129
+ # Output
130
+ if not self.use_linear_projection:
131
+ hidden_states = (
132
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
133
+ )
134
+ hidden_states = self.proj_out(hidden_states)
135
+ else:
136
+ hidden_states = self.proj_out(hidden_states)
137
+ hidden_states = (
138
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
139
+ )
140
+
141
+ output = hidden_states + residual
142
+
143
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
144
+ if not return_dict:
145
+ return (output,)
146
+
147
+ return SpatioTempTransformer3DModelOutput(sample=output)
148
+
149
+
150
+ class BasicTransformerBlock(nn.Module):
151
+ def __init__(
152
+ self,
153
+ dim: int,
154
+ num_attention_heads: int,
155
+ attention_head_dim: int,
156
+ dropout=0.0,
157
+ cross_attention_dim: Optional[int] = None,
158
+ activation_fn: str = "geglu",
159
+ num_embeds_ada_norm: Optional[int] = None,
160
+ attention_bias: bool = False,
161
+ only_cross_attention: bool = False,
162
+ double_self_attention: bool = False,
163
+ upcast_attention: bool = False,
164
+ norm_elementwise_affine: bool = True,
165
+ norm_type: str = "layer_norm",
166
+ final_dropout: bool = False,
167
+ ):
168
+ super().__init__()
169
+ self.only_cross_attention = only_cross_attention
170
+
171
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
172
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
173
+
174
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
175
+ raise ValueError(
176
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
177
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
178
+ )
179
+
180
+ # Define 4 blocks. Each block has its own normalization layer.
181
+ # 1. FF-Attn
182
+ if self.use_ada_layer_norm:
183
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
184
+ elif self.use_ada_layer_norm_zero:
185
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
186
+ else:
187
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
188
+ self.attn1 = FFAttention(
189
+ query_dim=dim,
190
+ heads=num_attention_heads,
191
+ dim_head=attention_head_dim,
192
+ dropout=dropout,
193
+ bias=attention_bias,
194
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
195
+ upcast_attention=upcast_attention,
196
+ )
197
+
198
+ # 2. Cross-Attn
199
+ if cross_attention_dim is not None or double_self_attention:
200
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
201
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
202
+ # the second cross attention block.
203
+ self.norm2 = (
204
+ AdaLayerNorm(dim, num_embeds_ada_norm)
205
+ if self.use_ada_layer_norm
206
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
207
+ )
208
+ self.attn2 = Attention(
209
+ query_dim=dim,
210
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
211
+ heads=num_attention_heads,
212
+ dim_head=attention_head_dim,
213
+ dropout=dropout,
214
+ bias=attention_bias,
215
+ upcast_attention=upcast_attention,
216
+ ) # is self-attn if encoder_hidden_states is none
217
+ else:
218
+ self.norm2 = None
219
+ self.attn2 = None
220
+
221
+ # 3. Temp-Attn
222
+
223
+ self.pos_proj_temp = Timesteps(dim, flip_sin_to_cos=True, downscale_freq_shift=0)
224
+ self.pos_embedding_temp = TimestepEmbedding(
225
+ dim,
226
+ dim,
227
+ act_fn="silu",
228
+ post_act_fn=None,
229
+ cond_proj_dim=None,
230
+ )
231
+
232
+ self.attn_temp = Attention(
233
+ query_dim=dim,
234
+ heads=num_attention_heads,
235
+ dim_head=attention_head_dim,
236
+ dropout=dropout,
237
+ bias=attention_bias,
238
+ upcast_attention=upcast_attention,
239
+ )
240
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
241
+ self.norm_temp = (
242
+ AdaLayerNorm(dim, num_embeds_ada_norm)
243
+ if self.use_ada_layer_norm
244
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
245
+ )
246
+
247
+ # 4. Feed-forward
248
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
249
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
250
+
251
+ def forward(
252
+ self,
253
+ hidden_states,
254
+ attention_mask=None,
255
+ encoder_hidden_states=None,
256
+ encoder_attention_mask=None,
257
+ timestep=None,
258
+ video_length=None,
259
+ cross_attention_kwargs=None,
260
+ class_labels=None,
261
+ ):
262
+ # Notice that normalization is always applied before the real computation in the following blocks.
263
+ # 1. Self-Attention
264
+ if self.use_ada_layer_norm:
265
+ norm_hidden_states = self.norm1(hidden_states, timestep)
266
+ elif self.use_ada_layer_norm_zero:
267
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
268
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
269
+ )
270
+ else:
271
+ norm_hidden_states = self.norm1(hidden_states)
272
+
273
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
274
+ attn_output = self.attn1(
275
+ norm_hidden_states,
276
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
277
+ attention_mask=attention_mask,
278
+ video_length=video_length,
279
+ **cross_attention_kwargs,
280
+ )
281
+ if self.use_ada_layer_norm_zero:
282
+ attn_output = gate_msa.unsqueeze(1) * attn_output
283
+ hidden_states = attn_output + hidden_states
284
+
285
+ # 2. Cross-Attention
286
+ if self.attn2 is not None:
287
+ norm_hidden_states = (
288
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
289
+ )
290
+ # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
291
+ # prepare attention mask here
292
+
293
+ attn_output = self.attn2(
294
+ norm_hidden_states,
295
+ encoder_hidden_states=encoder_hidden_states,
296
+ attention_mask=encoder_attention_mask,
297
+ **cross_attention_kwargs,
298
+ )
299
+ hidden_states = attn_output + hidden_states
300
+
301
+ # 3. Temporal-Attention
302
+
303
+ # Add positional embedding
304
+ device = hidden_states.device
305
+ dtype = hidden_states.dtype
306
+ pos_embed = self.pos_proj_temp(torch.arange(video_length).long()).to(device=device, dtype=dtype) # (f c)
307
+ pos_embed = self.pos_embedding_temp(pos_embed).unsqueeze(0) # (1, f, c)
308
+
309
+ seq_len = hidden_states.shape[1]
310
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
311
+ norm_hidden_states = (
312
+ self.norm_temp(hidden_states + pos_embed, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states+pos_embed)
313
+ )
314
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
315
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=seq_len)
316
+
317
+ # 4. Feed-forward
318
+ norm_hidden_states = self.norm3(hidden_states)
319
+
320
+ if self.use_ada_layer_norm_zero:
321
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
322
+
323
+ ff_output = self.ff(norm_hidden_states)
324
+
325
+ if self.use_ada_layer_norm_zero:
326
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
327
+
328
+ hidden_states = ff_output + hidden_states
329
+
330
+ return hidden_states
331
+
imagebind/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from imagebind import data
2
+ from imagebind.models import imagebind_model
3
+ from imagebind.models.imagebind_model import ModalityType
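
For context, the package re-exported here is the stock ImageBind release, so the usual embedding workflow should apply. This is a hedged sketch: imagebind_model.py itself is not shown in this diff, and imagebind_huge(pretrained=True) downloads the ~4.8 GB checkpoint referenced at the top of this commit (.checkpoints/imagebind_huge.pth).

import torch
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType

device = "cpu"
model = imagebind_model.imagebind_huge(pretrained=True).eval().to(device)

inputs = {
    ModalityType.AUDIO: data.load_and_transform_audio_data(
        ["assets/lions_roaring.wav"], device
    ),
}
with torch.no_grad():
    embeddings = model(inputs)
print(embeddings[ModalityType.AUDIO].shape)  # pooled audio embedding in the joint space
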
imagebind/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (335 Bytes).
 
imagebind/__pycache__/data.cpython-310.pyc ADDED
Binary file (9.37 kB).
 
imagebind/bpe/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
imagebind/data.py ADDED
@@ -0,0 +1,343 @@
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import logging
9
+ import math
10
+ import pkg_resources
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torchaudio
15
+ from PIL import Image
16
+ from pytorchvideo import transforms as pv_transforms
17
+ from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
18
+ from pytorchvideo.data.encoded_video import EncodedVideo
19
+ from torchvision import transforms
20
+ from torchvision.transforms._transforms_video import NormalizeVideo
21
+
22
+ from imagebind.models.multimodal_preprocessors import SimpleTokenizer
23
+
24
+ DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
25
+
26
+
27
+ def return_bpe_path():
28
+ return pkg_resources.resource_filename(
29
+ "imagebind", "bpe/bpe_simple_vocab_16e6.txt.gz"
30
+ )
31
+
32
+
33
+ def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
34
+ # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
35
+ waveform -= waveform.mean()
36
+ fbank = torchaudio.compliance.kaldi.fbank(
37
+ waveform,
38
+ htk_compat=True,
39
+ sample_frequency=sample_rate,
40
+ use_energy=False,
41
+ window_type="hanning",
42
+ num_mel_bins=num_mel_bins,
43
+ dither=0.0,
44
+ frame_length=25,
45
+ frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
46
+ )
47
+ # Convert to [mel_bins, num_frames] shape
48
+ fbank = fbank.transpose(0, 1)
49
+ # Pad to target_length
50
+ n_frames = fbank.size(1)
51
+ p = target_length - n_frames
52
+ # if p is too large (say >20%), flash a warning
53
+ if abs(p) / n_frames > 0.2:
54
+ logging.warning(
55
+ "Large gap between audio n_frames(%d) and "
56
+ "target_length (%d). Is the audio_target_length "
57
+ "setting correct?",
58
+ n_frames,
59
+ target_length,
60
+ )
61
+ # cut and pad
62
+ if p > 0:
63
+ fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
64
+ elif p < 0:
65
+ fbank = fbank[:, 0:target_length]
66
+ # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
67
+ # channel image
68
+ fbank = fbank.unsqueeze(0)
69
+ return fbank
70
+
71
+
72
+ def get_clip_timepoints(clip_sampler, duration):
73
+ # Read out all clips in this video
74
+ all_clips_timepoints = []
75
+ is_last_clip = False
76
+ end = 0.0
77
+ while not is_last_clip:
78
+ start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
79
+ all_clips_timepoints.append((start, end))
80
+ return all_clips_timepoints
81
+
82
+
83
+ def load_and_transform_vision_data(image_paths, device):
84
+ if image_paths is None:
85
+ return None
86
+
87
+ image_outputs = []
88
+
89
+ data_transform = transforms.Compose(
90
+ [
91
+ transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
92
+ transforms.CenterCrop(224),
93
+ transforms.ToTensor(),
94
+ transforms.Normalize(
95
+ mean=(0.48145466, 0.4578275, 0.40821073),
96
+ std=(0.26862954, 0.26130258, 0.27577711),
97
+ ),
98
+ ]
99
+ )
100
+
101
+ for image_path in image_paths:
102
+ with open(image_path, "rb") as fopen:
103
+ image = Image.open(fopen).convert("RGB")
104
+
105
+ image = data_transform(image).to(device)
106
+ image_outputs.append(image)
107
+ return torch.stack(image_outputs, dim=0)
108
+
109
+
110
+ def load_and_transform_text(text, device):
111
+ if text is None:
112
+ return None
113
+ tokenizer = SimpleTokenizer(bpe_path=return_bpe_path())
114
+ tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text]
115
+ tokens = torch.cat(tokens, dim=0)
116
+ return tokens
117
+
118
+
119
+ def load_and_transform_audio_data(
120
+ audio_paths,
121
+ device,
122
+ num_mel_bins=128,
123
+ target_length=204,
124
+ sample_rate=16000,
125
+ clip_duration=2,
126
+ clips_per_video=3,
127
+ mean=-4.268,
128
+ std=9.138,
129
+ ):
130
+ if audio_paths is None:
131
+ return None
132
+
133
+ audio_outputs = []
134
+ clip_sampler = ConstantClipsPerVideoSampler(
135
+ clip_duration=clip_duration, clips_per_video=clips_per_video
136
+ )
137
+
138
+ for audio_path in audio_paths:
139
+ waveform, sr = torchaudio.load(audio_path)
140
+ if sample_rate != sr:
141
+ waveform = torchaudio.functional.resample(
142
+ waveform, orig_freq=sr, new_freq=sample_rate
143
+ )
144
+ all_clips_timepoints = get_clip_timepoints(
145
+ clip_sampler, waveform.size(1) / sample_rate
146
+ )
147
+ all_clips = []
148
+ for clip_timepoints in all_clips_timepoints:
149
+ waveform_clip = waveform[
150
+ :,
151
+ int(clip_timepoints[0] * sample_rate) : int(
152
+ clip_timepoints[1] * sample_rate
153
+ ),
154
+ ]
155
+ waveform_melspec = waveform2melspec(
156
+ waveform_clip, sample_rate, num_mel_bins, target_length
157
+ )
158
+ all_clips.append(waveform_melspec)
159
+
160
+ normalize = transforms.Normalize(mean=mean, std=std)
161
+ all_clips = [normalize(ac).to(device) for ac in all_clips]
162
+
163
+ all_clips = torch.stack(all_clips, dim=0)
164
+ audio_outputs.append(all_clips)
165
+
166
+ return torch.stack(audio_outputs, dim=0)
167
+
168
+
169
+ def crop_boxes(boxes, x_offset, y_offset):
170
+ """
171
+ Perform crop on the bounding boxes given the offsets.
172
+ Args:
173
+ boxes (ndarray or None): bounding boxes to perform crop. The dimension
174
+ is `num boxes` x 4.
175
+ x_offset (int): cropping offset in the x axis.
176
+ y_offset (int): cropping offset in the y axis.
177
+ Returns:
178
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
179
+ `num boxes` x 4.
180
+ """
181
+ cropped_boxes = boxes.copy()
182
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
183
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
184
+
185
+ return cropped_boxes
186
+
187
+
188
+ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
189
+ """
190
+ Perform uniform spatial sampling on the images and corresponding boxes.
191
+ Args:
192
+ images (tensor): images to perform uniform crop. The dimension is
193
+ `num frames` x `channel` x `height` x `width`.
194
+ size (int): size of height and width to crop the images.
195
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
196
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
197
+ crop if height is larger than width.
198
+ boxes (ndarray or None): optional. Corresponding boxes to images.
199
+ Dimension is `num boxes` x 4.
200
+ scale_size (int): optional. If not None, resize the images to scale_size before
201
+ performing any crop.
202
+ Returns:
203
+ cropped (tensor): images with dimension of
204
+ `num frames` x `channel` x `size` x `size`.
205
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
206
+ `num boxes` x 4.
207
+ """
208
+ assert spatial_idx in [0, 1, 2]
209
+ ndim = len(images.shape)
210
+ if ndim == 3:
211
+ images = images.unsqueeze(0)
212
+ height = images.shape[2]
213
+ width = images.shape[3]
214
+
215
+ if scale_size is not None:
216
+ if width <= height:
217
+ width, height = scale_size, int(height / width * scale_size)
218
+ else:
219
+ width, height = int(width / height * scale_size), scale_size
220
+ images = torch.nn.functional.interpolate(
221
+ images,
222
+ size=(height, width),
223
+ mode="bilinear",
224
+ align_corners=False,
225
+ )
226
+
227
+ y_offset = int(math.ceil((height - size) / 2))
228
+ x_offset = int(math.ceil((width - size) / 2))
229
+
230
+ if height > width:
231
+ if spatial_idx == 0:
232
+ y_offset = 0
233
+ elif spatial_idx == 2:
234
+ y_offset = height - size
235
+ else:
236
+ if spatial_idx == 0:
237
+ x_offset = 0
238
+ elif spatial_idx == 2:
239
+ x_offset = width - size
240
+ cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
241
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
242
+ if ndim == 3:
243
+ cropped = cropped.squeeze(0)
244
+ return cropped, cropped_boxes
245
+
246
+
247
+ class SpatialCrop(nn.Module):
248
+ """
249
+ Convert the video into 3 smaller clips spatially. Must be used after the
250
+ temporal crops to get spatial crops, and should be used with
251
+ -2 in the spatial crop at the slowfast augmentation stage (so full
252
+ frames are passed in here). Will return a larger list with the
253
+ 3x spatial crops as well.
254
+ """
255
+
256
+ def __init__(self, crop_size: int = 224, num_crops: int = 3):
257
+ super().__init__()
258
+ self.crop_size = crop_size
259
+ if num_crops == 3:
260
+ self.crops_to_ext = [0, 1, 2]
261
+ self.flipped_crops_to_ext = []
262
+ elif num_crops == 1:
263
+ self.crops_to_ext = [1]
264
+ self.flipped_crops_to_ext = []
265
+ else:
266
+ raise NotImplementedError("Nothing else supported yet")
267
+
268
+ def forward(self, videos):
269
+ """
270
+ Args:
271
+ videos: A list of C, T, H, W videos.
272
+ Returns:
273
+ videos: A list with 3x the number of elements. Each video converted
274
+ to C, T, H', W' by spatial cropping.
275
+ """
276
+ assert isinstance(videos, list), "Must be a list of videos after temporal crops"
277
+ assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
278
+ res = []
279
+ for video in videos:
280
+ for spatial_idx in self.crops_to_ext:
281
+ res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
282
+ if not self.flipped_crops_to_ext:
283
+ continue
284
+ flipped_video = transforms.functional.hflip(video)
285
+ for spatial_idx in self.flipped_crops_to_ext:
286
+ res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
287
+ return res
288
+
289
+
290
+ def load_and_transform_video_data(
291
+ video_paths,
292
+ device,
293
+ clip_duration=2,
294
+ clips_per_video=5,
295
+ sample_rate=16000,
296
+ ):
297
+ if video_paths is None:
298
+ return None
299
+
300
+ video_outputs = []
301
+ video_transform = transforms.Compose(
302
+ [
303
+ pv_transforms.ShortSideScale(224),
304
+ NormalizeVideo(
305
+ mean=(0.48145466, 0.4578275, 0.40821073),
306
+ std=(0.26862954, 0.26130258, 0.27577711),
307
+ ),
308
+ ]
309
+ )
310
+
311
+ clip_sampler = ConstantClipsPerVideoSampler(
312
+ clip_duration=clip_duration, clips_per_video=clips_per_video
313
+ )
314
+ frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
315
+
316
+ for video_path in video_paths:
317
+ video = EncodedVideo.from_path(
318
+ video_path,
319
+ decoder="decord",
320
+ decode_audio=False,
321
+ **{"sample_rate": sample_rate},
322
+ )
323
+
324
+ all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
325
+
326
+ all_video = []
327
+ for clip_timepoints in all_clips_timepoints:
328
+ # Read the clip, get frames
329
+ clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
330
+ if clip is None:
331
+ raise ValueError("No clip found")
332
+ video_clip = frame_sampler(clip["video"])
333
+ video_clip = video_clip / 255.0 # since this is float, need 0-1
334
+
335
+ all_video.append(video_clip)
336
+
337
+ all_video = [video_transform(clip) for clip in all_video]
338
+ all_video = SpatialCrop(224, num_crops=3)(all_video)
339
+
340
+ all_video = torch.stack(all_video, dim=0)
341
+ video_outputs.append(all_video)
342
+
343
+ return torch.stack(video_outputs, dim=0).to(device)
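Taken together, these loaders produce the per-modality tensors that ImageBindModel.forward consumes as a dict keyed by ModalityType. A minimal sketch of that wiring (illustrative only, not code from this commit), assuming the repository is importable as the imagebind package and using placeholder file paths:

import torch

from imagebind import data
from imagebind.models.imagebind_model import ModalityType

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder paths: any local image and audio file will do.
inputs = {
    ModalityType.TEXT: data.load_and_transform_text(["a lion roaring"], device),
    ModalityType.VISION: data.load_and_transform_vision_data(["example.jpg"], device),
    ModalityType.AUDIO: data.load_and_transform_audio_data(["example.wav"], device),
}
# Expected shapes: text (1, 77) token ids, vision (1, 3, 224, 224),
# audio (1, clips_per_video, 1, num_mel_bins, target_length) = (1, 3, 1, 128, 204);
# a 2 s clip at the 10 ms frame shift gives roughly 200 fbank frames, padded to 204 above.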
imagebind/models/__init__.py ADDED
File without changes
imagebind/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (169 Bytes)
imagebind/models/__pycache__/helpers.cpython-310.pyc ADDED
Binary file (5.14 kB)
imagebind/models/__pycache__/imagebind_model.cpython-310.pyc ADDED
Binary file (8.3 kB)
imagebind/models/__pycache__/multimodal_preprocessors.cpython-310.pyc ADDED
Binary file (19.9 kB)
imagebind/models/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (8.01 kB)
imagebind/models/helpers.py ADDED
@@ -0,0 +1,140 @@
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import einops
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+
15
+ class Normalize(nn.Module):
16
+ def __init__(self, dim: int) -> None:
17
+ super().__init__()
18
+ self.dim = dim
19
+
20
+ def forward(self, x):
21
+ return torch.nn.functional.normalize(x, dim=self.dim, p=2)
22
+
23
+
24
+ class LearnableLogitScaling(nn.Module):
25
+ def __init__(
26
+ self,
27
+ logit_scale_init: float = 1 / 0.07,
28
+ learnable: bool = True,
29
+ max_logit_scale: float = 100,
30
+ ) -> None:
31
+ super().__init__()
32
+ self.max_logit_scale = max_logit_scale
33
+ self.logit_scale_init = logit_scale_init
34
+ self.learnable = learnable
35
+ log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
36
+ if learnable:
37
+ self.log_logit_scale = nn.Parameter(log_logit_scale)
38
+ else:
39
+ self.register_buffer("log_logit_scale", log_logit_scale)
40
+
41
+ def forward(self, x):
42
+ return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
43
+
44
+ def extra_repr(self):
45
+ st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}," \
46
+ f" max_logit_scale={self.max_logit_scale}"
47
+ return st
48
+
49
+
50
+ class EinOpsRearrange(nn.Module):
51
+ def __init__(self, rearrange_expr: str, **kwargs) -> None:
52
+ super().__init__()
53
+ self.rearrange_expr = rearrange_expr
54
+ self.kwargs = kwargs
55
+
56
+ def forward(self, x):
57
+ assert isinstance(x, torch.Tensor)
58
+ return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
59
+
60
+
61
+ class VerboseNNModule(nn.Module):
62
+ """
63
+ Wrapper around nn.Module that prints registered buffers and parameter names.
64
+ """
65
+
66
+ @staticmethod
67
+ def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
68
+ st = (
69
+ "("
70
+ + name
71
+ + "): "
72
+ + "tensor("
73
+ + str(tuple(tensor[1].shape))
74
+ + ", requires_grad="
75
+ + str(tensor[1].requires_grad)
76
+ + ")\n"
77
+ )
78
+ return st
79
+
80
+ def extra_repr(self) -> str:
81
+ named_modules = set()
82
+ for p in self.named_modules():
83
+ named_modules.update([p[0]])
84
+ named_modules = list(named_modules)
85
+
86
+ string_repr = ""
87
+ for p in self.named_parameters():
88
+ name = p[0].split(".")[0]
89
+ if name not in named_modules:
90
+ string_repr += self.get_readable_tensor_repr(name, p)
91
+
92
+ for p in self.named_buffers():
93
+ name = p[0].split(".")[0]
94
+ string_repr += self.get_readable_tensor_repr(name, p)
95
+
96
+ return string_repr
97
+
98
+
99
+ def cast_if_src_dtype(
100
+ tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
101
+ ):
102
+ updated = False
103
+ if tensor.dtype == src_dtype:
104
+ tensor = tensor.to(dtype=tgt_dtype)
105
+ updated = True
106
+ return tensor, updated
107
+
108
+
109
+ class QuickGELU(nn.Module):
110
+ # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
111
+ def forward(self, x: torch.Tensor):
112
+ return x * torch.sigmoid(1.702 * x)
113
+
114
+
115
+ class SelectElement(nn.Module):
116
+ def __init__(self, index) -> None:
117
+ super().__init__()
118
+ self.index = index
119
+
120
+ def forward(self, x):
121
+ assert x.ndim >= 3
122
+ return x[:, self.index, ...]
123
+
124
+
125
+ class SelectEOSAndProject(nn.Module):
126
+ """
127
+ Text Pooling used in OpenCLIP
128
+ """
129
+
130
+ def __init__(self, proj: nn.Module) -> None:
131
+ super().__init__()
132
+ self.proj = proj
133
+
134
+ def forward(self, x, seq_len):
135
+ assert x.ndim == 3
136
+ # x is of shape B x L x D
137
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
138
+ x = x[torch.arange(x.shape[0]), seq_len]
139
+ x = self.proj(x)
140
+ return x
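These helpers are the building blocks of the modality heads and postprocessors defined in imagebind_model.py below. A small sketch of how Normalize and LearnableLogitScaling compose (illustrative only, not code from this commit):

import torch
import torch.nn as nn

from imagebind.models.helpers import LearnableLogitScaling, Normalize

# Mirrors the audio postprocessor built later: L2-normalize, then apply a frozen logit scale of 20.
post = nn.Sequential(
    Normalize(dim=-1),
    LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
)
emb = torch.randn(4, 1024)
out = post(emb)
print(out.norm(dim=-1))  # every row has norm ~20.0 after normalization and scaling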
imagebind/models/imagebind_model.py ADDED
@@ -0,0 +1,506 @@
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import os
10
+ from functools import partial
11
+ from types import SimpleNamespace
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+ from imagebind.models.helpers import (EinOpsRearrange, LearnableLogitScaling, Normalize,
17
+ SelectElement, SelectEOSAndProject)
18
+ from imagebind.models.multimodal_preprocessors import (AudioPreprocessor,
19
+ IMUPreprocessor, PadIm2Video,
20
+ PatchEmbedGeneric,
21
+ RGBDTPreprocessor,
22
+ SpatioTemporalPosEmbeddingHelper,
23
+ TextPreprocessor,
24
+ ThermalPreprocessor)
25
+ from imagebind.models.transformer import MultiheadAttention, SimpleTransformer
26
+
27
+ ModalityType = SimpleNamespace(
28
+ VISION="vision",
29
+ TEXT="text",
30
+ AUDIO="audio",
31
+ THERMAL="thermal",
32
+ DEPTH="depth",
33
+ IMU="imu",
34
+ )
35
+
36
+
37
+ class ImageBindModel(nn.Module):
38
+ def __init__(
39
+ self,
40
+ video_frames=2,
41
+ kernel_size=(2, 14, 14),
42
+ audio_kernel_size=16,
43
+ audio_stride=10,
44
+ out_embed_dim=768,
45
+ vision_embed_dim=1024,
46
+ vision_num_blocks=24,
47
+ vision_num_heads=16,
48
+ audio_embed_dim=768,
49
+ audio_num_blocks=12,
50
+ audio_num_heads=12,
51
+ audio_num_mel_bins=128,
52
+ audio_target_len=204,
53
+ audio_drop_path=0.1,
54
+ text_embed_dim=768,
55
+ text_num_blocks=12,
56
+ text_num_heads=12,
57
+ depth_embed_dim=384,
58
+ depth_kernel_size=16,
59
+ depth_num_blocks=12,
60
+ depth_num_heads=8,
61
+ depth_drop_path=0.0,
62
+ thermal_embed_dim=768,
63
+ thermal_kernel_size=16,
64
+ thermal_num_blocks=12,
65
+ thermal_num_heads=12,
66
+ thermal_drop_path=0.0,
67
+ imu_embed_dim=512,
68
+ imu_kernel_size=8,
69
+ imu_num_blocks=6,
70
+ imu_num_heads=8,
71
+ imu_drop_path=0.7,
72
+ ):
73
+ super().__init__()
74
+
75
+ self.modality_preprocessors = self._create_modality_preprocessors(
76
+ video_frames,
77
+ vision_embed_dim,
78
+ kernel_size,
79
+ text_embed_dim,
80
+ audio_embed_dim,
81
+ audio_kernel_size,
82
+ audio_stride,
83
+ audio_num_mel_bins,
84
+ audio_target_len,
85
+ depth_embed_dim,
86
+ depth_kernel_size,
87
+ thermal_embed_dim,
88
+ thermal_kernel_size,
89
+ imu_embed_dim,
90
+ )
91
+
92
+ self.modality_trunks = self._create_modality_trunks(
93
+ vision_embed_dim,
94
+ vision_num_blocks,
95
+ vision_num_heads,
96
+ text_embed_dim,
97
+ text_num_blocks,
98
+ text_num_heads,
99
+ audio_embed_dim,
100
+ audio_num_blocks,
101
+ audio_num_heads,
102
+ audio_drop_path,
103
+ depth_embed_dim,
104
+ depth_num_blocks,
105
+ depth_num_heads,
106
+ depth_drop_path,
107
+ thermal_embed_dim,
108
+ thermal_num_blocks,
109
+ thermal_num_heads,
110
+ thermal_drop_path,
111
+ imu_embed_dim,
112
+ imu_num_blocks,
113
+ imu_num_heads,
114
+ imu_drop_path,
115
+ )
116
+
117
+ self.modality_heads = self._create_modality_heads(
118
+ out_embed_dim,
119
+ vision_embed_dim,
120
+ text_embed_dim,
121
+ audio_embed_dim,
122
+ depth_embed_dim,
123
+ thermal_embed_dim,
124
+ imu_embed_dim,
125
+ )
126
+
127
+ self.modality_postprocessors = self._create_modality_postprocessors(
128
+ out_embed_dim
129
+ )
130
+
131
+ def _create_modality_preprocessors(
132
+ self,
133
+ video_frames=2,
134
+ vision_embed_dim=1024,
135
+ kernel_size=(2, 14, 14),
136
+ text_embed_dim=768,
137
+ audio_embed_dim=768,
138
+ audio_kernel_size=16,
139
+ audio_stride=10,
140
+ audio_num_mel_bins=128,
141
+ audio_target_len=204,
142
+ depth_embed_dim=768,
143
+ depth_kernel_size=16,
144
+ thermal_embed_dim=768,
145
+ thermal_kernel_size=16,
146
+ imu_embed_dim=512,
147
+ ):
148
+ rgbt_stem = PatchEmbedGeneric(
149
+ proj_stem=[
150
+ PadIm2Video(pad_type="repeat", ntimes=2),
151
+ nn.Conv3d(
152
+ in_channels=3,
153
+ kernel_size=kernel_size,
154
+ out_channels=vision_embed_dim,
155
+ stride=kernel_size,
156
+ bias=False,
157
+ ),
158
+ ]
159
+ )
160
+ rgbt_preprocessor = RGBDTPreprocessor(
161
+ img_size=[3, video_frames, 224, 224],
162
+ num_cls_tokens=1,
163
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
164
+ rgbt_stem=rgbt_stem,
165
+ depth_stem=None,
166
+ )
167
+
168
+ text_preprocessor = TextPreprocessor(
169
+ context_length=77,
170
+ vocab_size=49408,
171
+ embed_dim=text_embed_dim,
172
+ causal_masking=True,
173
+ )
174
+
175
+ audio_stem = PatchEmbedGeneric(
176
+ proj_stem=[
177
+ nn.Conv2d(
178
+ in_channels=1,
179
+ kernel_size=audio_kernel_size,
180
+ stride=audio_stride,
181
+ out_channels=audio_embed_dim,
182
+ bias=False,
183
+ ),
184
+ ],
185
+ norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
186
+ )
187
+ audio_preprocessor = AudioPreprocessor(
188
+ img_size=[1, audio_num_mel_bins, audio_target_len],
189
+ num_cls_tokens=1,
190
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
191
+ audio_stem=audio_stem,
192
+ )
193
+
194
+ depth_stem = PatchEmbedGeneric(
195
+ [
196
+ nn.Conv2d(
197
+ kernel_size=depth_kernel_size,
198
+ in_channels=1,
199
+ out_channels=depth_embed_dim,
200
+ stride=depth_kernel_size,
201
+ bias=False,
202
+ ),
203
+ ],
204
+ norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
205
+ )
206
+
207
+ depth_preprocessor = RGBDTPreprocessor(
208
+ img_size=[1, 224, 224],
209
+ num_cls_tokens=1,
210
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
211
+ rgbt_stem=None,
212
+ depth_stem=depth_stem,
213
+ )
214
+
215
+ thermal_stem = PatchEmbedGeneric(
216
+ [
217
+ nn.Conv2d(
218
+ kernel_size=thermal_kernel_size,
219
+ in_channels=1,
220
+ out_channels=thermal_embed_dim,
221
+ stride=thermal_kernel_size,
222
+ bias=False,
223
+ ),
224
+ ],
225
+ norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
226
+ )
227
+ thermal_preprocessor = ThermalPreprocessor(
228
+ img_size=[1, 224, 224],
229
+ num_cls_tokens=1,
230
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
231
+ thermal_stem=thermal_stem,
232
+ )
233
+
234
+ imu_stem = PatchEmbedGeneric(
235
+ [
236
+ nn.Linear(
237
+ in_features=48,
238
+ out_features=imu_embed_dim,
239
+ bias=False,
240
+ ),
241
+ ],
242
+ norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
243
+ )
244
+
245
+ imu_preprocessor = IMUPreprocessor(
246
+ img_size=[6, 2000],
247
+ num_cls_tokens=1,
248
+ kernel_size=8,
249
+ embed_dim=imu_embed_dim,
250
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
251
+ imu_stem=imu_stem,
252
+ )
253
+
254
+ modality_preprocessors = {
255
+ ModalityType.VISION: rgbt_preprocessor,
256
+ ModalityType.TEXT: text_preprocessor,
257
+ ModalityType.AUDIO: audio_preprocessor,
258
+ ModalityType.DEPTH: depth_preprocessor,
259
+ ModalityType.THERMAL: thermal_preprocessor,
260
+ ModalityType.IMU: imu_preprocessor,
261
+ }
262
+
263
+ return nn.ModuleDict(modality_preprocessors)
264
+
265
+ def _create_modality_trunks(
266
+ self,
267
+ vision_embed_dim=1024,
268
+ vision_num_blocks=24,
269
+ vision_num_heads=16,
270
+ text_embed_dim=768,
271
+ text_num_blocks=12,
272
+ text_num_heads=12,
273
+ audio_embed_dim=768,
274
+ audio_num_blocks=12,
275
+ audio_num_heads=12,
276
+ audio_drop_path=0.0,
277
+ depth_embed_dim=768,
278
+ depth_num_blocks=12,
279
+ depth_num_heads=12,
280
+ depth_drop_path=0.0,
281
+ thermal_embed_dim=768,
282
+ thermal_num_blocks=12,
283
+ thermal_num_heads=12,
284
+ thermal_drop_path=0.0,
285
+ imu_embed_dim=512,
286
+ imu_num_blocks=6,
287
+ imu_num_heads=8,
288
+ imu_drop_path=0.7,
289
+ ):
290
+ def instantiate_trunk(
291
+ embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
292
+ ):
293
+ return SimpleTransformer(
294
+ embed_dim=embed_dim,
295
+ num_blocks=num_blocks,
296
+ ffn_dropout_rate=0.0,
297
+ drop_path_rate=drop_path,
298
+ attn_target=partial(
299
+ MultiheadAttention,
300
+ embed_dim=embed_dim,
301
+ num_heads=num_heads,
302
+ bias=True,
303
+ add_bias_kv=add_bias_kv,
304
+ ),
305
+ pre_transformer_layer=nn.Sequential(
306
+ nn.LayerNorm(embed_dim, eps=1e-6)
307
+ if pre_transformer_ln
308
+ else nn.Identity(),
309
+ EinOpsRearrange("b l d -> l b d"),
310
+ ),
311
+ post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
312
+ )
313
+
314
+ modality_trunks = {}
315
+ modality_trunks[ModalityType.VISION] = instantiate_trunk(
316
+ vision_embed_dim,
317
+ vision_num_blocks,
318
+ vision_num_heads,
319
+ pre_transformer_ln=True,
320
+ add_bias_kv=False,
321
+ drop_path=0.0,
322
+ )
323
+ modality_trunks[ModalityType.TEXT] = instantiate_trunk(
324
+ text_embed_dim,
325
+ text_num_blocks,
326
+ text_num_heads,
327
+ pre_transformer_ln=False,
328
+ add_bias_kv=False,
329
+ drop_path=0.0,
330
+ )
331
+ modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
332
+ audio_embed_dim,
333
+ audio_num_blocks,
334
+ audio_num_heads,
335
+ pre_transformer_ln=False,
336
+ add_bias_kv=True,
337
+ drop_path=audio_drop_path,
338
+ )
339
+ modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
340
+ depth_embed_dim,
341
+ depth_num_blocks,
342
+ depth_num_heads,
343
+ pre_transformer_ln=False,
344
+ add_bias_kv=True,
345
+ drop_path=depth_drop_path,
346
+ )
347
+ modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
348
+ thermal_embed_dim,
349
+ thermal_num_blocks,
350
+ thermal_num_heads,
351
+ pre_transformer_ln=False,
352
+ add_bias_kv=True,
353
+ drop_path=thermal_drop_path,
354
+ )
355
+ modality_trunks[ModalityType.IMU] = instantiate_trunk(
356
+ imu_embed_dim,
357
+ imu_num_blocks,
358
+ imu_num_heads,
359
+ pre_transformer_ln=False,
360
+ add_bias_kv=True,
361
+ drop_path=imu_drop_path,
362
+ )
363
+
364
+ return nn.ModuleDict(modality_trunks)
365
+
366
+ def _create_modality_heads(
367
+ self,
368
+ out_embed_dim,
369
+ vision_embed_dim,
370
+ text_embed_dim,
371
+ audio_embed_dim,
372
+ depth_embed_dim,
373
+ thermal_embed_dim,
374
+ imu_embed_dim,
375
+ ):
376
+ modality_heads = {}
377
+
378
+ modality_heads[ModalityType.VISION] = nn.Sequential(
379
+ nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
380
+ SelectElement(index=0),
381
+ nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
382
+ )
383
+
384
+ modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
385
+ proj=nn.Sequential(
386
+ nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
387
+ nn.Linear(text_embed_dim, out_embed_dim, bias=False),
388
+ )
389
+ )
390
+
391
+ modality_heads[ModalityType.AUDIO] = nn.Sequential(
392
+ nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
393
+ SelectElement(index=0),
394
+ nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
395
+ )
396
+
397
+ modality_heads[ModalityType.DEPTH] = nn.Sequential(
398
+ nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
399
+ SelectElement(index=0),
400
+ nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
401
+ )
402
+
403
+ modality_heads[ModalityType.THERMAL] = nn.Sequential(
404
+ nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
405
+ SelectElement(index=0),
406
+ nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
407
+ )
408
+
409
+ modality_heads[ModalityType.IMU] = nn.Sequential(
410
+ nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
411
+ SelectElement(index=0),
412
+ nn.Dropout(p=0.5),
413
+ nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
414
+ )
415
+
416
+ return nn.ModuleDict(modality_heads)
417
+
418
+ def _create_modality_postprocessors(self, out_embed_dim):
419
+ modality_postprocessors = {}
420
+
421
+ modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
422
+ modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
423
+ Normalize(dim=-1), LearnableLogitScaling(learnable=True)
424
+ )
425
+ modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
426
+ Normalize(dim=-1),
427
+ LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
428
+ )
429
+ modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
430
+ Normalize(dim=-1),
431
+ LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
432
+ )
433
+ modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
434
+ Normalize(dim=-1),
435
+ LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
436
+ )
437
+ modality_postprocessors[ModalityType.IMU] = nn.Sequential(
438
+ Normalize(dim=-1),
439
+ LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
440
+ )
441
+
442
+ return nn.ModuleDict(modality_postprocessors)
443
+
444
+ def forward(self, inputs):
445
+ outputs = {}
446
+ for modality_key, modality_value in inputs.items():
447
+ reduce_list = (
448
+ modality_value.ndim >= 5
449
+ ) # Audio and Video inputs consist of multiple clips
450
+ if reduce_list:
451
+ B, S = modality_value.shape[:2]
452
+ modality_value = modality_value.reshape(
453
+ B * S, *modality_value.shape[2:]
454
+ )
455
+
456
+ if modality_value is not None:
457
+ modality_value = self.modality_preprocessors[modality_key](
458
+ **{modality_key: modality_value}
459
+ )
460
+ trunk_inputs = modality_value["trunk"]
461
+ head_inputs = modality_value["head"]
462
+ modality_value = self.modality_trunks[modality_key](**trunk_inputs)
463
+ modality_value = self.modality_heads[modality_key](
464
+ modality_value, **head_inputs
465
+ )
466
+ modality_value = self.modality_postprocessors[modality_key](
467
+ modality_value
468
+ )
469
+
470
+ if reduce_list:
471
+ modality_value = modality_value.reshape(B, S, -1)
472
+ modality_value = modality_value.mean(dim=1)
473
+
474
+ outputs[modality_key] = modality_value
475
+
476
+ return outputs
477
+
478
+
479
+ def imagebind_huge(pretrained=False):
480
+ model = ImageBindModel(
481
+ vision_embed_dim=1280,
482
+ vision_num_blocks=32,
483
+ vision_num_heads=16,
484
+ text_embed_dim=1024,
485
+ text_num_blocks=24,
486
+ text_num_heads=16,
487
+ out_embed_dim=1024,
488
+ audio_drop_path=0.1,
489
+ imu_drop_path=0.7,
490
+ )
491
+
492
+ if pretrained:
493
+ if not os.path.exists(".checkpoints/imagebind_huge.pth"):
494
+ print(
495
+ "Downloading imagebind weights to .checkpoints/imagebind_huge.pth ..."
496
+ )
497
+ os.makedirs(".checkpoints", exist_ok=True)
498
+ torch.hub.download_url_to_file(
499
+ "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
500
+ ".checkpoints/imagebind_huge.pth",
501
+ progress=True,
502
+ )
503
+
504
+ model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth"))
505
+
506
+ return model
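A hedged end-to-end sketch of driving this model, mirroring the upstream ImageBind usage rather than anything in this commit (the audio path below is a placeholder): every modality head projects into the shared out_embed_dim space and the postprocessors normalize and scale the embeddings, so cross-modal retrieval reduces to a dot product between forward() outputs.

import torch

from imagebind import data
from imagebind.models.imagebind_model import ModalityType, imagebind_huge

device = "cuda" if torch.cuda.is_available() else "cpu"
model = imagebind_huge(pretrained=True).eval().to(device)  # downloads the weights on first use

inputs = {
    ModalityType.TEXT: data.load_and_transform_text(["a dog barking", "a car engine"], device),
    ModalityType.AUDIO: data.load_and_transform_audio_data(["example.wav"], device),  # placeholder path
}
with torch.no_grad():
    embeddings = model(inputs)

# (1, 1024) audio embedding vs. (2, 1024) text embeddings -> (1, 2) audio-to-text scores.
scores = torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1)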
imagebind/models/multimodal_preprocessors.py ADDED
@@ -0,0 +1,685 @@
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import gzip
9
+ import html
10
+ import io
11
+ import math
12
+ from functools import lru_cache
13
+ from typing import Callable, List, Optional, Tuple
14
+
15
+ import ftfy
16
+ import numpy as np
17
+ import regex as re
18
+ import torch
19
+ import torch.nn as nn
20
+ from iopath.common.file_io import g_pathmgr
21
+ from timm.models.layers import trunc_normal_
22
+
23
+ from imagebind.models.helpers import VerboseNNModule, cast_if_src_dtype
24
+
25
+
26
+ def get_sinusoid_encoding_table(n_position, d_hid):
27
+ """Sinusoid position encoding table"""
28
+
29
+ # TODO: make it with torch instead of numpy
30
+ def get_position_angle_vec(position):
31
+ return [
32
+ position / np.power(10000, 2 * (hid_j // 2) / d_hid)
33
+ for hid_j in range(d_hid)
34
+ ]
35
+
36
+ sinusoid_table = np.array(
37
+ [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
38
+ )
39
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
40
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
41
+
42
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
43
+
44
+
45
+ def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
46
+ N = pos_embed.shape[1]
47
+ if N == target_spatial_size:
48
+ return pos_embed
49
+ dim = pos_embed.shape[-1]
50
+ # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
51
+ pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
52
+ pos_embed = nn.functional.interpolate(
53
+ pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
54
+ 0, 3, 1, 2
55
+ ),
56
+ scale_factor=math.sqrt(target_spatial_size / N),
57
+ mode="bicubic",
58
+ )
59
+ if updated:
60
+ pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16)
61
+ pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
62
+ return pos_embed
63
+
64
+
65
+ def interpolate_pos_encoding(
66
+ npatch_per_img,
67
+ pos_embed,
68
+ patches_layout,
69
+ input_shape=None,
70
+ first_patch_idx=1,
71
+ ):
72
+ assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none"
73
+ N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists
74
+ if npatch_per_img == N:
75
+ return pos_embed
76
+
77
+ assert (
78
+ patches_layout[-1] == patches_layout[-2]
79
+ ), "Interpolation of pos embed not supported for non-square layouts"
80
+
81
+ class_emb = pos_embed[:, :first_patch_idx]
82
+ pos_embed = pos_embed[:, first_patch_idx:]
83
+
84
+ if input_shape is None or patches_layout[0] == 1:
85
+ # simple 2D pos embedding, no temporal component
86
+ pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed)
87
+ elif patches_layout[0] > 1:
88
+ # pos embed has a temporal component
89
+ assert len(input_shape) == 4, "temporal interpolation not supported"
90
+ # we only support 2D interpolation in this case
91
+ num_frames = patches_layout[0]
92
+ num_spatial_tokens = patches_layout[1] * patches_layout[2]
93
+ pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1)
94
+ # interpolate embedding for zeroth frame
95
+ pos_embed = interpolate_pos_encoding_2d(
96
+ npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0)
97
+ )
98
+ else:
99
+ raise ValueError("This type of interpolation isn't implemented")
100
+
101
+ return torch.cat((class_emb, pos_embed), dim=1)
102
+
103
+
104
+ def _get_pos_embedding(
105
+ npatch_per_img,
106
+ pos_embed,
107
+ patches_layout,
108
+ input_shape,
109
+ first_patch_idx=1,
110
+ ):
111
+ pos_embed = interpolate_pos_encoding(
112
+ npatch_per_img,
113
+ pos_embed,
114
+ patches_layout,
115
+ input_shape=input_shape,
116
+ first_patch_idx=first_patch_idx,
117
+ )
118
+ return pos_embed
119
+
120
+
121
+ class PatchEmbedGeneric(nn.Module):
122
+ """
123
+ PatchEmbed from Hydra
124
+ """
125
+
126
+ def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
127
+ super().__init__()
128
+
129
+ if len(proj_stem) > 1:
130
+ self.proj = nn.Sequential(*proj_stem)
131
+ else:
132
+ # Special case to be able to load pre-trained models that were
133
+ # trained with a standard stem
134
+ self.proj = proj_stem[0]
135
+ self.norm_layer = norm_layer
136
+
137
+ def get_patch_layout(self, img_size):
138
+ with torch.no_grad():
139
+ dummy_img = torch.zeros(
140
+ [
141
+ 1,
142
+ ]
143
+ + img_size
144
+ )
145
+ dummy_out = self.proj(dummy_img)
146
+ embed_dim = dummy_out.shape[1]
147
+ patches_layout = tuple(dummy_out.shape[2:])
148
+ num_patches = np.prod(patches_layout)
149
+ return patches_layout, num_patches, embed_dim
150
+
151
+ def forward(self, x):
152
+ x = self.proj(x)
153
+ # B C (T) H W -> B (T)HW C
154
+ x = x.flatten(2).transpose(1, 2)
155
+ if self.norm_layer is not None:
156
+ x = self.norm_layer(x)
157
+ return x
158
+
159
+
160
+ class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
161
+ def __init__(
162
+ self,
163
+ patches_layout: List,
164
+ num_patches: int,
165
+ num_cls_tokens: int,
166
+ embed_dim: int,
167
+ learnable: bool,
168
+ ) -> None:
169
+ super().__init__()
170
+ self.num_cls_tokens = num_cls_tokens
171
+ self.patches_layout = patches_layout
172
+ self.num_patches = num_patches
173
+ self.num_tokens = num_cls_tokens + num_patches
174
+ self.learnable = learnable
175
+ if self.learnable:
176
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
177
+ trunc_normal_(self.pos_embed, std=0.02)
178
+ else:
179
+ self.register_buffer(
180
+ "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim)
181
+ )
182
+
183
+ def get_pos_embedding(self, vision_input, all_vision_tokens):
184
+ input_shape = vision_input.shape
185
+ pos_embed = _get_pos_embedding(
186
+ all_vision_tokens.size(1) - self.num_cls_tokens,
187
+ pos_embed=self.pos_embed,
188
+ patches_layout=self.patches_layout,
189
+ input_shape=input_shape,
190
+ first_patch_idx=self.num_cls_tokens,
191
+ )
192
+ return pos_embed
193
+
194
+
195
+ class RGBDTPreprocessor(VerboseNNModule):
196
+ def __init__(
197
+ self,
198
+ rgbt_stem: PatchEmbedGeneric,
199
+ depth_stem: Optional[PatchEmbedGeneric],
200
+ img_size: Tuple = (3, 224, 224),
201
+ num_cls_tokens: int = 1,
202
+ pos_embed_fn: Optional[Callable] = None,
203
+ use_type_embed: bool = False,
204
+ init_param_style: str = "openclip",
205
+ ) -> None:
206
+ super().__init__()
207
+ stem = rgbt_stem if rgbt_stem is not None else depth_stem
208
+ (
209
+ self.patches_layout,
210
+ self.num_patches,
211
+ self.embed_dim,
212
+ ) = stem.get_patch_layout(img_size)
213
+ self.rgbt_stem = rgbt_stem
214
+ self.depth_stem = depth_stem
215
+ self.use_pos_embed = pos_embed_fn is not None
216
+ self.use_type_embed = use_type_embed
217
+ self.num_cls_tokens = num_cls_tokens
218
+
219
+ if self.use_pos_embed:
220
+ self.pos_embedding_helper = pos_embed_fn(
221
+ patches_layout=self.patches_layout,
222
+ num_cls_tokens=num_cls_tokens,
223
+ num_patches=self.num_patches,
224
+ embed_dim=self.embed_dim,
225
+ )
226
+ if self.num_cls_tokens > 0:
227
+ self.cls_token = nn.Parameter(
228
+ torch.zeros(1, self.num_cls_tokens, self.embed_dim)
229
+ )
230
+ if self.use_type_embed:
231
+ self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
232
+
233
+ self.init_parameters(init_param_style)
234
+
235
+ @torch.no_grad()
236
+ def init_parameters(self, init_param_style):
237
+ if init_param_style == "openclip":
238
+ # OpenCLIP style initialization
239
+ scale = self.embed_dim**-0.5
240
+ if self.use_pos_embed:
241
+ nn.init.normal_(self.pos_embedding_helper.pos_embed)
242
+ self.pos_embedding_helper.pos_embed *= scale
243
+
244
+ if self.num_cls_tokens > 0:
245
+ nn.init.normal_(self.cls_token)
246
+ self.cls_token *= scale
247
+ elif init_param_style == "vit":
248
+ self.cls_token.data.fill_(0)
249
+ else:
250
+ raise ValueError(f"Unknown init {init_param_style}")
251
+
252
+ if self.use_type_embed:
253
+ nn.init.normal_(self.type_embed)
254
+
255
+ def tokenize_input_and_cls_pos(self, input, stem, mask):
256
+ # tokens is of shape B x L x D
257
+ tokens = stem(input)
258
+ assert tokens.ndim == 3
259
+ assert tokens.shape[2] == self.embed_dim
260
+ B = tokens.shape[0]
261
+ if self.num_cls_tokens > 0:
262
+ class_tokens = self.cls_token.expand(
263
+ B, -1, -1
264
+ ) # stole class_tokens impl from Phil Wang, thanks
265
+ tokens = torch.cat((class_tokens, tokens), dim=1)
266
+ if self.use_pos_embed:
267
+ pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
268
+ tokens = tokens + pos_embed
269
+ if self.use_type_embed:
270
+ tokens = tokens + self.type_embed.expand(B, -1, -1)
271
+ return tokens
272
+
273
+ def forward(self, vision=None, depth=None, patch_mask=None):
274
+ if patch_mask is not None:
275
+ raise NotImplementedError()
276
+
277
+ if vision is not None:
278
+ vision_tokens = self.tokenize_input_and_cls_pos(
279
+ vision, self.rgbt_stem, patch_mask
280
+ )
281
+
282
+ if depth is not None:
283
+ depth_tokens = self.tokenize_input_and_cls_pos(
284
+ depth, self.depth_stem, patch_mask
285
+ )
286
+
287
+ # aggregate tokens
288
+ if vision is not None and depth is not None:
289
+ final_tokens = vision_tokens + depth_tokens
290
+ else:
291
+ final_tokens = vision_tokens if vision is not None else depth_tokens
292
+ return_dict = {
293
+ "trunk": {
294
+ "tokens": final_tokens,
295
+ },
296
+ "head": {},
297
+ }
298
+ return return_dict
299
+
300
+
301
+ class AudioPreprocessor(RGBDTPreprocessor):
302
+ def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
303
+ super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)
304
+
305
+ def forward(self, audio=None):
306
+ return super().forward(vision=audio)
307
+
308
+
309
+ class ThermalPreprocessor(RGBDTPreprocessor):
310
+ def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
311
+ super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)
312
+
313
+ def forward(self, thermal=None):
314
+ return super().forward(vision=thermal)
315
+
316
+
317
+ def build_causal_attention_mask(context_length):
318
+ # lazily create causal attention mask, with full attention between the vision tokens
319
+ # pytorch uses additive attention mask; fill with -inf
320
+ mask = torch.empty(context_length, context_length, requires_grad=False)
321
+ mask.fill_(float("-inf"))
322
+ mask.triu_(1) # zero out the lower diagonal
323
+ return mask
324
+
325
+
326
+ class TextPreprocessor(VerboseNNModule):
327
+ def __init__(
328
+ self,
329
+ vocab_size: int,
330
+ context_length: int,
331
+ embed_dim: int,
332
+ causal_masking: bool,
333
+ supply_seq_len_to_head: bool = True,
334
+ num_cls_tokens: int = 0,
335
+ init_param_style: str = "openclip",
336
+ ) -> None:
337
+ super().__init__()
338
+ self.vocab_size = vocab_size
339
+ self.context_length = context_length
340
+ self.token_embedding = nn.Embedding(vocab_size, embed_dim)
341
+ self.pos_embed = nn.Parameter(
342
+ torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
343
+ )
344
+ self.causal_masking = causal_masking
345
+ if self.causal_masking:
346
+ mask = build_causal_attention_mask(self.context_length)
347
+ # register the mask as a buffer so it can be moved to the right device
348
+ self.register_buffer("mask", mask)
349
+
350
+ self.supply_seq_len_to_head = supply_seq_len_to_head
351
+ self.num_cls_tokens = num_cls_tokens
352
+ self.embed_dim = embed_dim
353
+ if num_cls_tokens > 0:
354
+ assert self.causal_masking is False, "Masking + CLS token isn't implemented"
355
+ self.cls_token = nn.Parameter(
356
+ torch.zeros(1, self.num_cls_tokens, embed_dim)
357
+ )
358
+
359
+ self.init_parameters(init_param_style)
360
+
361
+ @torch.no_grad()
362
+ def init_parameters(self, init_param_style="openclip"):
363
+ # OpenCLIP style initialization
364
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
365
+ nn.init.normal_(self.pos_embed, std=0.01)
366
+
367
+ if init_param_style == "openclip":
368
+ # OpenCLIP style initialization
369
+ scale = self.embed_dim**-0.5
370
+ if self.num_cls_tokens > 0:
371
+ nn.init.normal_(self.cls_token)
372
+ self.cls_token *= scale
373
+ elif init_param_style == "vit":
374
+ self.cls_token.data.fill_(0)
375
+ else:
376
+ raise ValueError(f"Unknown init {init_param_style}")
377
+
378
+ def forward(self, text):
379
+ # text tokens are of shape B x L x D
380
+ text_tokens = self.token_embedding(text)
381
+ # concat CLS tokens if any
382
+ if self.num_cls_tokens > 0:
383
+ B = text_tokens.shape[0]
384
+ class_tokens = self.cls_token.expand(
385
+ B, -1, -1
386
+ ) # stole class_tokens impl from Phil Wang, thanks
387
+ text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
388
+ text_tokens = text_tokens + self.pos_embed
389
+ return_dict = {
390
+ "trunk": {
391
+ "tokens": text_tokens,
392
+ },
393
+ "head": {},
394
+ }
395
+ # Compute sequence length after adding CLS tokens
396
+ if self.supply_seq_len_to_head:
397
+ text_lengths = text.argmax(dim=-1)
398
+ return_dict["head"] = {
399
+ "seq_len": text_lengths,
400
+ }
401
+ if self.causal_masking:
402
+ return_dict["trunk"].update({"attn_mask": self.mask})
403
+ return return_dict
404
+
405
+
406
+ class Im2Video(nn.Module):
407
+ """Convert an image into a trivial video."""
408
+
409
+ def __init__(self, time_dim=2):
410
+ super().__init__()
411
+ self.time_dim = time_dim
412
+
413
+ def forward(self, x):
414
+ if x.ndim == 4:
415
+ # B, C, H, W -> B, C, T, H, W
416
+ return x.unsqueeze(self.time_dim)
417
+ elif x.ndim == 5:
418
+ return x
419
+ else:
420
+ raise ValueError(f"Dimension incorrect {x.shape}")
421
+
422
+
423
+ class PadIm2Video(Im2Video):
424
+ def __init__(self, ntimes, pad_type, time_dim=2):
425
+ super().__init__(time_dim=time_dim)
426
+ assert ntimes > 0
427
+ assert pad_type in ["zero", "repeat"]
428
+ self.ntimes = ntimes
429
+ self.pad_type = pad_type
430
+
431
+ def forward(self, x):
432
+ x = super().forward(x)
433
+ if x.shape[self.time_dim] == 1:
434
+ if self.pad_type == "repeat":
435
+ new_shape = [1] * len(x.shape)
436
+ new_shape[self.time_dim] = self.ntimes
437
+ x = x.repeat(new_shape)
438
+ elif self.pad_type == "zero":
439
+ padarg = [0, 0] * len(x.shape)
440
+ padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim]
441
+ x = nn.functional.pad(x, padarg)
442
+ return x
443
+
444
+
445
+ # Modified from github.com/openai/CLIP
446
+ @lru_cache()
447
+ def bytes_to_unicode():
448
+ """
449
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
450
+ The reversible bpe codes work on unicode strings.
451
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
452
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
453
+ This is a significant percentage of your normal, say, 32K bpe vocab.
454
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
455
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
456
+ """
457
+ bs = (
458
+ list(range(ord("!"), ord("~") + 1))
459
+ + list(range(ord("¡"), ord("¬") + 1))
460
+ + list(range(ord("®"), ord("ÿ") + 1))
461
+ )
462
+ cs = bs[:]
463
+ n = 0
464
+ for b in range(2**8):
465
+ if b not in bs:
466
+ bs.append(b)
467
+ cs.append(2**8 + n)
468
+ n += 1
469
+ cs = [chr(n) for n in cs]
470
+ return dict(zip(bs, cs))
471
+
472
+
473
+ def get_pairs(word):
474
+ """Return set of symbol pairs in a word.
475
+ Word is represented as tuple of symbols (symbols being variable-length strings).
476
+ """
477
+ pairs = set()
478
+ prev_char = word[0]
479
+ for char in word[1:]:
480
+ pairs.add((prev_char, char))
481
+ prev_char = char
482
+ return pairs
483
+
484
+
485
+ def basic_clean(text):
486
+ text = ftfy.fix_text(text)
487
+ text = html.unescape(html.unescape(text))
488
+ return text.strip()
489
+
490
+
491
+ def whitespace_clean(text):
492
+ text = re.sub(r"\s+", " ", text)
493
+ text = text.strip()
494
+ return text
495
+
496
+
497
+ class SimpleTokenizer(object):
498
+ def __init__(self, bpe_path: str, context_length=77):
499
+ self.byte_encoder = bytes_to_unicode()
500
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
501
+
502
+ with g_pathmgr.open(bpe_path, "rb") as fh:
503
+ bpe_bytes = io.BytesIO(fh.read())
504
+ merges: List[str] = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
505
+ merges = merges[1 : 49152 - 256 - 2 + 1]
506
+ merges: List[Tuple[str, ...]] = [tuple(merge.split()) for merge in merges]
507
+ vocab = list(bytes_to_unicode().values())
508
+ vocab = vocab + [v + "</w>" for v in vocab]
509
+ for merge in merges:
510
+ vocab.append("".join(merge))
511
+ vocab.extend(["<|startoftext|>", "<|endoftext|>"])
512
+ self.encoder = dict(zip(vocab, range(len(vocab))))
513
+ self.decoder = {v: k for k, v in self.encoder.items()}
514
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
515
+ self.cache = {
516
+ "<|startoftext|>": "<|startoftext|>",
517
+ "<|endoftext|>": "<|endoftext|>",
518
+ }
519
+ self.pat = re.compile(
520
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
521
+ re.IGNORECASE,
522
+ )
523
+ self.context_length = context_length
524
+
525
+ def bpe(self, token):
526
+ if token in self.cache:
527
+ return self.cache[token]
528
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
529
+ pairs = get_pairs(word)
530
+
531
+ if not pairs:
532
+ return token + "</w>"
533
+
534
+ while True:
535
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
536
+ if bigram not in self.bpe_ranks:
537
+ break
538
+ first, second = bigram
539
+ new_word = []
540
+ i = 0
541
+ while i < len(word):
542
+ try:
543
+ j = word.index(first, i)
544
+ new_word.extend(word[i:j])
545
+ i = j
546
+ except ValueError:
547
+ new_word.extend(word[i:])
548
+ break
549
+
550
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
551
+ new_word.append(first + second)
552
+ i += 2
553
+ else:
554
+ new_word.append(word[i])
555
+ i += 1
556
+ new_word = tuple(new_word)
557
+ word = new_word
558
+ if len(word) == 1:
559
+ break
560
+ else:
561
+ pairs = get_pairs(word)
562
+ word = " ".join(word)
563
+ self.cache[token] = word
564
+ return word
565
+
566
+ def encode(self, text):
567
+ bpe_tokens = []
568
+ text = whitespace_clean(basic_clean(text)).lower()
569
+ for token in re.findall(self.pat, text):
570
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
571
+ bpe_tokens.extend(
572
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
573
+ )
574
+ return bpe_tokens
575
+
576
+ def decode(self, tokens):
577
+ text = "".join([self.decoder[token] for token in tokens])
578
+ text = (
579
+ bytearray([self.byte_decoder[c] for c in text])
580
+ .decode("utf-8", errors="replace")
581
+ .replace("</w>", " ")
582
+ )
583
+ return text
584
+
585
+ def __call__(self, texts, context_length=None):
586
+ if not context_length:
587
+ context_length = self.context_length
588
+
589
+ if isinstance(texts, str):
590
+ texts = [texts]
591
+
592
+ sot_token = self.encoder["<|startoftext|>"]
593
+ eot_token = self.encoder["<|endoftext|>"]
594
+ all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
595
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
596
+
597
+ for i, tokens in enumerate(all_tokens):
598
+ tokens = tokens[:context_length]
599
+ result[i, : len(tokens)] = torch.tensor(tokens)
600
+
601
+ if len(result) == 1:
602
+ return result[0]
603
+ return result
604
+
605
+
606
+ class IMUPreprocessor(VerboseNNModule):
607
+ def __init__(
608
+ self,
609
+ kernel_size: int,
610
+ imu_stem: PatchEmbedGeneric,
611
+ embed_dim: int,
612
+ img_size: Tuple = (6, 2000),
613
+ num_cls_tokens: int = 1,
614
+ pos_embed_fn: Optional[Callable] = None,
615
+ init_param_style: str = "openclip",
616
+ ) -> None:
617
+ super().__init__()
618
+ self.imu_stem = imu_stem
619
+ self.embed_dim = embed_dim
620
+ self.use_pos_embed = pos_embed_fn is not None
621
+ self.num_cls_tokens = num_cls_tokens
622
+ self.kernel_size = kernel_size
623
+ self.pos_embed = nn.Parameter(
624
+ torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
625
+ )
626
+
627
+ if self.num_cls_tokens > 0:
628
+ self.cls_token = nn.Parameter(
629
+ torch.zeros(1, self.num_cls_tokens, self.embed_dim)
630
+ )
631
+
632
+ self.init_parameters(init_param_style)
633
+
634
+ @torch.no_grad()
635
+ def init_parameters(self, init_param_style):
636
+ nn.init.normal_(self.pos_embed, std=0.01)
637
+
638
+ if init_param_style == "openclip":
639
+ # OpenCLIP style initialization
640
+ scale = self.embed_dim**-0.5
641
+
642
+ if self.num_cls_tokens > 0:
643
+ nn.init.normal_(self.cls_token)
644
+ self.cls_token *= scale
645
+ elif init_param_style == "vit":
646
+ self.cls_token.data.fill_(0)
647
+ else:
648
+ raise ValueError(f"Unknown init {init_param_style}")
649
+
650
+ def tokenize_input_and_cls_pos(self, input, stem):
651
+ # tokens is of shape B x L x D
652
+ tokens = stem.norm_layer(stem.proj(input))
653
+ assert tokens.ndim == 3
654
+ assert tokens.shape[2] == self.embed_dim
655
+ B = tokens.shape[0]
656
+ if self.num_cls_tokens > 0:
657
+ class_tokens = self.cls_token.expand(
658
+ B, -1, -1
659
+ ) # stole class_tokens impl from Phil Wang, thanks
660
+ tokens = torch.cat((class_tokens, tokens), dim=1)
661
+ if self.use_pos_embed:
662
+ tokens = tokens + self.pos_embed
663
+ return tokens
664
+
665
+ def forward(self, imu):
666
+ # Patchify
667
+ imu = imu.unfold(
668
+ -1,
669
+ self.kernel_size,
670
+ self.kernel_size,
671
+ ).permute(0, 2, 1, 3)
672
+ imu = imu.reshape(imu.size(0), imu.size(1), -1)
673
+
674
+ imu_tokens = self.tokenize_input_and_cls_pos(
675
+ imu,
676
+ self.imu_stem,
677
+ )
678
+
679
+ return_dict = {
680
+ "trunk": {
681
+ "tokens": imu_tokens,
682
+ },
683
+ "head": {},
684
+ }
685
+ return return_dict
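For orientation, a brief sketch of the text path (illustrative only, not part of this commit): SimpleTokenizer is the CLIP BPE tokenizer, so it cleans and lower-cases the input, brackets it with <|startoftext|> / <|endoftext|>, and zero-pads to context_length=77. TextPreprocessor can later recover the sequence end with argmax because <|endoftext|> has the highest id in the vocabulary.

from imagebind.data import return_bpe_path
from imagebind.models.multimodal_preprocessors import SimpleTokenizer

tokenizer = SimpleTokenizer(bpe_path=return_bpe_path())
tokens = tokenizer("A lion roaring!")   # a single string returns a 1-D LongTensor
print(tokens.shape)    # torch.Size([77])
print(tokens[:6])      # start token, BPE ids for the cleaned text, end token
batch = tokenizer(["a lion roaring", "ocean waves crashing"])  # a list returns shape (2, 77)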
imagebind/models/transformer.py ADDED
@@ -0,0 +1,280 @@
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # Code modified from
9
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ;
10
+ # https://github.com/facebookresearch/deit/blob/main/models.py
11
+ # and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py
12
+
13
+
14
+ from functools import partial
15
+ from typing import Callable, List, Optional
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint as checkpoint
20
+ from timm.models.layers import DropPath, trunc_normal_
21
+
22
+
23
+ class Attention(nn.Module):
24
+ def __init__(
25
+ self,
26
+ dim,
27
+ num_heads=8,
28
+ qkv_bias=False,
29
+ qk_scale=None,
30
+ attn_drop=0.0,
31
+ proj_drop=0.0,
32
+ ):
33
+ super().__init__()
34
+ self.num_heads = num_heads
35
+ head_dim = dim // num_heads
36
+ # NOTE scale factor was wrong in my original version,
37
+ # can set manually to be compat with prev weights
38
+ self.scale = qk_scale or head_dim**-0.5
39
+
40
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
41
+ self.attn_drop = nn.Dropout(attn_drop)
42
+ self.proj = nn.Linear(dim, dim)
43
+ self.proj_drop = nn.Dropout(proj_drop)
44
+
45
+ def forward(self, x):
46
+ B, N, C = x.shape
47
+ qkv = (
48
+ self.qkv(x)
49
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
50
+ .permute(2, 0, 3, 1, 4)
51
+ )
52
+ q, k, v = (
53
+ qkv[0],
54
+ qkv[1],
55
+ qkv[2],
56
+ ) # make torchscript happy (cannot use tensor as tuple)
57
+
58
+ attn = (q @ k.transpose(-2, -1)) * self.scale
59
+ attn = attn.softmax(dim=-1)
60
+ attn = self.attn_drop(attn)
61
+
62
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
63
+ x = self.proj(x)
64
+ x = self.proj_drop(x)
65
+ return x
66
+
67
+
68
+ class Mlp(nn.Module):
69
+ def __init__(
70
+ self,
71
+ in_features,
72
+ hidden_features=None,
73
+ out_features=None,
74
+ act_layer=nn.GELU,
75
+ drop=0.0,
76
+ ):
77
+ super().__init__()
78
+ out_features = out_features or in_features
79
+ hidden_features = hidden_features or in_features
80
+ self.fc1 = nn.Linear(in_features, hidden_features)
81
+ self.act = act_layer()
82
+ self.fc2 = nn.Linear(hidden_features, out_features)
83
+ self.drop = nn.Dropout(drop)
84
+
85
+ def forward(self, x):
86
+ x = self.fc1(x)
87
+ x = self.act(x)
88
+ x = self.drop(x)
89
+ x = self.fc2(x)
90
+ x = self.drop(x)
91
+ return x
92
+
93
+
94
+ class MultiheadAttention(nn.MultiheadAttention):
95
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
96
+ return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
97
+
98
+
99
+ class ViTAttention(Attention):
100
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
101
+ assert attn_mask is None
102
+ return super().forward(x)
103
+
104
+
105
+ class BlockWithMasking(nn.Module):
106
+ def __init__(
107
+ self,
108
+ dim: int,
109
+ attn_target: Callable,
110
+ mlp_ratio: int = 4,
111
+ act_layer: Callable = nn.GELU,
112
+ norm_layer: Callable = nn.LayerNorm,
113
+ ffn_dropout_rate: float = 0.0,
114
+ drop_path: float = 0.0,
115
+ layer_scale_type: Optional[str] = None,
116
+ layer_scale_init_value: float = 1e-4,
117
+ ):
118
+ super().__init__()
119
+
120
+ assert not isinstance(
121
+ attn_target, nn.Module
122
+ ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
123
+ self.attn = attn_target()
124
+ if drop_path > 0.0:
125
+ self.drop_path = DropPath(drop_path)
126
+ else:
127
+ self.drop_path = nn.Identity()
128
+ self.norm_1 = norm_layer(dim)
129
+ mlp_hidden_dim = int(mlp_ratio * dim)
130
+ self.mlp = Mlp(
131
+ in_features=dim,
132
+ hidden_features=mlp_hidden_dim,
133
+ act_layer=act_layer,
134
+ drop=ffn_dropout_rate,
135
+ )
136
+ self.norm_2 = norm_layer(dim)
137
+ self.layer_scale_type = layer_scale_type
138
+ if self.layer_scale_type is not None:
139
+ assert self.layer_scale_type in [
140
+ "per_channel",
141
+ "scalar",
142
+ ], f"Found Layer scale type {self.layer_scale_type}"
143
+ if self.layer_scale_type == "per_channel":
144
+ # one gamma value per channel
145
+ gamma_shape = [1, 1, dim]
146
+ elif self.layer_scale_type == "scalar":
147
+ # single gamma value for all channels
148
+ gamma_shape = [1, 1, 1]
149
+ # two gammas: for each part of the fwd in the encoder
150
+ self.layer_scale_gamma1 = nn.Parameter(
151
+ torch.ones(size=gamma_shape) * layer_scale_init_value,
152
+ requires_grad=True,
153
+ )
154
+ self.layer_scale_gamma2 = nn.Parameter(
155
+ torch.ones(size=gamma_shape) * layer_scale_init_value,
156
+ requires_grad=True,
157
+ )
158
+
159
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
160
+ if self.layer_scale_type is None:
161
+ x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask))
162
+ x = x + self.drop_path(self.mlp(self.norm_2(x)))
163
+ else:
164
+ x = (
165
+ x
166
+ + self.drop_path(self.attn(self.norm_1(x), attn_mask))
167
+ * self.layer_scale_gamma1
168
+ )
169
+ x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2
170
+ return x
171
+
172
+
173
+ _LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)
174
+
175
+
176
+ class SimpleTransformer(nn.Module):
177
+ def __init__(
178
+ self,
179
+ attn_target: Callable,
180
+ embed_dim: int,
181
+ num_blocks: int,
182
+ block: Callable = BlockWithMasking,
183
+ pre_transformer_layer: Optional[Callable] = None,
184
+ post_transformer_layer: Optional[Callable] = None,
185
+ drop_path_rate: float = 0.0,
186
+ drop_path_type: str = "progressive",
187
+ norm_layer: Callable = _LAYER_NORM,
188
+ mlp_ratio: int = 4,
189
+ ffn_dropout_rate: float = 0.0,
190
+ layer_scale_type: Optional[str] = None, # from cait; possible values are None, "per_channel", "scalar"
191
+ layer_scale_init_value: float = 1e-4, # from cait; float
192
+ weight_init_style: str = "jax", # possible values jax or pytorch
193
+ ):
194
+ """
195
+ Simple Transformer with the following features
196
+ 1. Supports masked attention
197
+ 2. Supports DropPath
198
+ 3. Supports LayerScale
199
+ 4. Supports Dropout in Attention and FFN
200
+ 5. Makes few assumptions about the input except that it is a Tensor
201
+ """
202
+ super().__init__()
203
+ self.pre_transformer_layer = pre_transformer_layer
204
+ if drop_path_type == "progressive":
205
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
206
+ elif drop_path_type == "uniform":
207
+ dpr = [drop_path_rate for i in range(num_blocks)]
208
+ else:
209
+ raise ValueError(f"Unknown drop_path_type: {drop_path_type}")
210
+
211
+ self.blocks = nn.Sequential(
212
+ *[
213
+ block(
214
+ dim=embed_dim,
215
+ attn_target=attn_target,
216
+ mlp_ratio=mlp_ratio,
217
+ ffn_dropout_rate=ffn_dropout_rate,
218
+ drop_path=dpr[i],
219
+ norm_layer=norm_layer,
220
+ layer_scale_type=layer_scale_type,
221
+ layer_scale_init_value=layer_scale_init_value,
222
+ )
223
+ for i in range(num_blocks)
224
+ ]
225
+ )
226
+ self.post_transformer_layer = post_transformer_layer
227
+ self.weight_init_style = weight_init_style
228
+ self.apply(self._init_weights)
229
+
230
+ def _init_weights(self, m):
231
+ if isinstance(m, nn.Linear):
232
+ if self.weight_init_style == "jax":
233
+ # Based on MAE and official Jax ViT implementation
234
+ torch.nn.init.xavier_uniform_(m.weight)
235
+ elif self.weight_init_style == "pytorch":
236
+ # PyTorch ViT uses trunc_normal_
237
+ trunc_normal_(m.weight, std=0.02)
238
+
239
+ if m.bias is not None:
240
+ nn.init.constant_(m.bias, 0)
241
+ elif isinstance(m, (nn.LayerNorm)):
242
+ nn.init.constant_(m.bias, 0)
243
+ nn.init.constant_(m.weight, 1.0)
244
+
245
+ def forward(
246
+ self,
247
+ tokens: torch.Tensor,
248
+ attn_mask: torch.Tensor = None,
249
+ use_checkpoint: bool = False,
250
+ checkpoint_every_n: int = 1,
251
+ checkpoint_blk_ids: Optional[List[int]] = None,
252
+ ):
253
+ """
254
+ Inputs
255
+ - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
256
+ - attn_mask: attention mask of shape L x L
257
+
258
+ Output
259
+ - x: data of shape N x L x D (or L x N x D depending on the attention implementation)
260
+ """
261
+ if self.pre_transformer_layer:
262
+ tokens = self.pre_transformer_layer(tokens)
263
+ if use_checkpoint and checkpoint_blk_ids is None:
264
+ checkpoint_blk_ids = [
265
+ blk_id
266
+ for blk_id in range(len(self.blocks))
267
+ if blk_id % checkpoint_every_n == 0
268
+ ]
269
+ if checkpoint_blk_ids:
270
+ checkpoint_blk_ids = set(checkpoint_blk_ids)
271
+ for blk_id, blk in enumerate(self.blocks):
272
+ if use_checkpoint and blk_id in checkpoint_blk_ids:
273
+ tokens = checkpoint.checkpoint(
274
+ blk, tokens, attn_mask, use_reentrant=False
275
+ )
276
+ else:
277
+ tokens = blk(tokens, attn_mask=attn_mask)
278
+ if self.post_transformer_layer:
279
+ tokens = self.post_transformer_layer(tokens)
280
+ return tokens
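The assertion in BlockWithMasking above requires `attn_target` to be a factory (a zero-argument callable) rather than a module, so every block builds its own attention weights. A minimal sketch of wiring SimpleTransformer up that way; the embedding size, head count, and other settings here are illustrative choices, not ImageBind's actual configuration:

from functools import partial

import torch

from imagebind.models.transformer import MultiheadAttention, SimpleTransformer

# attn_target is called once per block, producing an independent attention module each time.
transformer = SimpleTransformer(
    attn_target=partial(MultiheadAttention, embed_dim=256, num_heads=8, batch_first=True),
    embed_dim=256,
    num_blocks=2,
    drop_path_rate=0.1,
    layer_scale_type="per_channel",
)

tokens = torch.randn(2, 16, 256)           # (N, L, D) because batch_first=True
out = transformer(tokens, attn_mask=None)  # same shape as the input
print(out.shape)                           # torch.Size([2, 16, 256])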
pipeline.py ADDED
@@ -0,0 +1,602 @@
1
+ import torchvision.io
2
+ from einops import rearrange, repeat
3
+ import numpy as np
4
+ import inspect
5
+ from typing import List, Optional, Union, Tuple
6
+
7
+ import os
8
+ import PIL
9
+ import torch
10
+ import torchaudio
11
+ import torchvision.io
12
+ import torchvision.transforms as transforms
13
+
14
+ from transformers import ImageProcessingMixin
15
+
16
+ from diffusers.loaders import TextualInversionLoaderMixin
17
+ from diffusers.models import AutoencoderKL
18
+ from diffusers.schedulers import KarrasDiffusionSchedulers, PNDMScheduler
19
+ from diffusers.utils import logging
20
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
21
+ from diffusers.image_processor import VaeImageProcessor
22
+
23
+ from unet import AudioUNet3DConditionModel
24
+ from audio_encoder import ImageBindSegmaskAudioEncoder
25
+ from imagebind.data import waveform2melspec
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ def waveform_to_melspectrogram(
31
+ waveform: Union[np.ndarray, torch.Tensor],
32
+ num_mel_bins=128,
33
+ target_length=204,
34
+ sample_rate=16000,
35
+ clip_duration=2.,
36
+ mean=-4.268,
37
+ std=9.138
38
+ ):
39
+ if isinstance(waveform, np.ndarray):
40
+ waveform = torch.from_numpy(waveform)
41
+
42
+ audio_length = waveform.shape[1]
43
+ audio_target_length = int(clip_duration * sample_rate)
44
+
45
+ audio_start_idx = 0
46
+ if audio_length > audio_target_length:
47
+ audio_start_idx = (audio_length - audio_target_length) // 2
48
+ audio_end_idx = audio_start_idx + audio_target_length
49
+ waveform_clip = waveform[:, audio_start_idx:audio_end_idx]
50
+
51
+ waveform_melspec = waveform2melspec(
52
+ waveform_clip, sample_rate, num_mel_bins, target_length
53
+ ) # (1, n_mel, n_frame)
54
+
55
+ normalize = transforms.Normalize(mean=mean, std=std)
56
+
57
+ audio_clip = normalize(waveform_melspec)
58
+
59
+ return audio_clip # (1, freq, time)
60
+
61
+
62
+ class AudioMelspectrogramExtractor(ImageProcessingMixin):
63
+
64
+ def __init__(
65
+ self,
66
+ num_mel_bins=128,
67
+ target_length=204,
68
+ sample_rate=16000,
69
+ clip_duration=2,
70
+ mean=-4.268,
71
+ std=9.138
72
+ ):
73
+ super().__init__()
74
+ self.num_mel_bins = num_mel_bins
75
+ self.target_length = target_length
76
+ self.sample_rate = sample_rate
77
+ self.clip_duration = clip_duration
78
+ self.mean = mean
79
+ self.std = std
80
+
81
+ @property
82
+ def max_length_s(self) -> int:
83
+ return self.clip_duration
84
+
85
+ @property
86
+ def sampling_rate(self) -> int:
87
+ return self.sample_rate
88
+
89
+ def __call__(
90
+ self,
91
+ waveforms: Union[
92
+ np.ndarray,
93
+ torch.Tensor,
94
+ List[np.ndarray],
95
+ List[torch.Tensor]
96
+ ]
97
+ ):
98
+ if isinstance(waveforms, (np.ndarray, torch.Tensor)) and waveforms.ndim == 2:
99
+ waveforms = [waveforms, ]
100
+ features = []
101
+
102
+ for waveform in waveforms:
103
+ feature = waveform_to_melspectrogram(
104
+ waveform=waveform,
105
+ num_mel_bins=self.num_mel_bins,
106
+ target_length=self.target_length,
107
+ sample_rate=self.sample_rate,
108
+ clip_duration=self.clip_duration,
109
+ mean=self.mean,
110
+ std=self.std
111
+ )
112
+ features.append(feature)
113
+ features = torch.stack(features, dim=0)
114
+
115
+ return features # (b c n t)
116
+
117
+
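A small sketch of what the extractor above produces: a 2-second, 16 kHz mono waveform becomes one normalized mel-spectrogram of shape (1, 1, 128, 204), matching the shapes noted in the code comments. The random waveform is synthetic and only for illustration; run it from the repository root so `pipeline` resolves to this file:

import torch

from pipeline import AudioMelspectrogramExtractor

extractor = AudioMelspectrogramExtractor()  # defaults: 128 mel bins, 204 frames, 16 kHz, 2 s clips
waveform = torch.randn(1, 32000)            # (channels, samples): 2 s of 16 kHz mono audio
melspec = extractor([waveform])             # list of waveforms -> batched spectrograms (b c n t)
print(melspec.shape)                        # torch.Size([1, 1, 128, 204])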
118
+ class AudioCondAnimationPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
119
+ """
120
+ Pipeline for audio-conditioned image animation: generating a short video from a still image, a driving audio clip, and a category text encoding.
121
+
122
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
123
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
124
+
125
+ Args:
126
+ audio_encoder ([`ImageBindSegmaskAudioEncoder`]):
127
+ Audio encoder to embed the conditioning audio melspectrograms.
128
+ unet ([`AudioUNet3DConditionModel`]): Conditional 3D U-Net architecture to denoise the encoded video latents.
129
+ scheduler ([`KarrasDiffusionSchedulers`]):
130
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
131
+ vae ([`AutoencoderKL`]):
132
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
133
+ """
134
+ unet: AudioUNet3DConditionModel
135
+ scheduler: KarrasDiffusionSchedulers
136
+ vae: AutoencoderKL
137
+ audio_encoder: ImageBindSegmaskAudioEncoder
138
+
139
+ def __init__(
140
+ self,
141
+ unet: AudioUNet3DConditionModel,
142
+ scheduler: KarrasDiffusionSchedulers,
143
+ vae: AutoencoderKL,
144
+ audio_encoder: ImageBindSegmaskAudioEncoder,
145
+ null_text_encodings_path: str = ""
146
+ ):
147
+ super().__init__()
148
+
149
+ self.register_modules(
150
+ unet=unet,
151
+ scheduler=scheduler,
152
+ vae=vae,
153
+ audio_encoder=audio_encoder
154
+ )
155
+
156
+ if null_text_encodings_path:
157
+ self.null_text_encoding = torch.load(null_text_encodings_path).view(1, 77, 768)
158
+
159
+ self.melspectrogram_shape = (128, 204)
160
+
161
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
162
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
163
+ self.audio_processor = AudioMelspectrogramExtractor()
164
+
165
+ @torch.no_grad()
166
+ def encode_text(
167
+ self,
168
+ text_encodings,
169
+ device,
170
+ dtype,
171
+ do_text_classifier_free_guidance,
172
+ do_audio_classifier_free_guidance,
173
+ ):
174
+ if isinstance(text_encodings, (List, Tuple)):
175
+ text_encodings = torch.cat(text_encodings)
176
+
177
+ text_encodings = text_encodings.to(dtype=dtype, device=device)
178
+ batch_size = len(text_encodings)
179
+
180
+ # get unconditional embeddings for classifier free guidance
181
+ if do_text_classifier_free_guidance:
182
+ if not hasattr(self, "null_text_encoding"):
183
+ uncond_token = ""
184
+
185
+ max_length = text_encodings.shape[1]
186
+ uncond_input = self.tokenizer(
187
+ uncond_token,
188
+ padding="max_length",
189
+ max_length=max_length,
190
+ truncation=True,
191
+ return_tensors="pt",
192
+ )
193
+
194
+ if hasattr(self.text_encoder.config,
195
+ "use_attention_mask") and self.text_encoder.config.use_attention_mask:
196
+ attention_mask = uncond_input.attention_mask.to(device)
197
+ else:
198
+ attention_mask = None
199
+
200
+ uncond_text_encodings = self.text_encoder(
201
+ uncond_input.input_ids.to(device),
202
+ attention_mask=attention_mask,
203
+ )
204
+ uncond_text_encodings = uncond_text_encodings[0]
205
+
206
+ else:
207
+ uncond_text_encodings = self.null_text_encoding
208
+
209
+ uncond_text_encodings = repeat(uncond_text_encodings, "1 n d -> b n d", b=batch_size).contiguous()
210
+ uncond_text_encodings = uncond_text_encodings.to(dtype=dtype, device=device)
211
+
212
+ if do_text_classifier_free_guidance and do_audio_classifier_free_guidance: # dual cfg
213
+ text_encodings = torch.cat([uncond_text_encodings, text_encodings, text_encodings])
214
+ elif do_text_classifier_free_guidance: # only text cfg
215
+ text_encodings = torch.cat([uncond_text_encodings, text_encodings])
216
+ elif do_audio_classifier_free_guidance: # only audio cfg
217
+ text_encodings = torch.cat([text_encodings, text_encodings])
218
+
219
+ return text_encodings
220
+
221
+ @torch.no_grad()
222
+ def encode_audio(
223
+ self,
224
+ audios: Union[List[np.ndarray], List[torch.Tensor]],
225
+ video_length: int = 12,
226
+ do_text_classifier_free_guidance: bool = False,
227
+ do_audio_classifier_free_guidance: bool = False,
228
+ device: torch.device = torch.device("cuda:0"),
229
+ dtype: torch.dtype = torch.float32
230
+ ):
231
+ batch_size = len(audios)
232
+ melspectrograms = self.audio_processor(audios).to(device=device, dtype=dtype) # (b c n t)
233
+
234
+ # audio_encodings: (b, n, c)
235
+ # audio_masks: (b, s, n)
236
+ _, audio_encodings, audio_masks = self.audio_encoder(
237
+ melspectrograms, normalize=False, return_dict=False
238
+ )
239
+ audio_encodings = repeat(audio_encodings, "b n c -> b f n c", f=video_length)
240
+
241
+ if do_audio_classifier_free_guidance:
242
+ null_melspectrograms = torch.zeros(1, 1, *self.melspectrogram_shape).to(device=device, dtype=dtype)
243
+ _, null_audio_encodings, null_audio_masks = self.audio_encoder(
244
+ null_melspectrograms, normalize=False, return_dict=False
245
+ )
246
+ null_audio_encodings = repeat(null_audio_encodings, "1 n c -> b f n c", b=batch_size, f=video_length)
247
+
248
+ if do_text_classifier_free_guidance and do_audio_classifier_free_guidance: # dual cfg
249
+ audio_encodings = torch.cat([null_audio_encodings, null_audio_encodings, audio_encodings])
250
+ audio_masks = torch.cat([null_audio_masks, null_audio_masks, audio_masks])
251
+ elif do_text_classifier_free_guidance: # only text cfg
252
+ audio_encodings = torch.cat([audio_encodings, audio_encodings])
253
+ audio_masks = torch.cat([audio_masks, audio_masks])
254
+ elif do_audio_classifier_free_guidance: # only audio cfg
255
+ audio_encodings = torch.cat([null_audio_encodings, audio_encodings])
256
+ audio_masks = torch.cat([null_audio_masks, audio_masks])
257
+
258
+ return audio_encodings, audio_masks
259
+
260
+ @torch.no_grad()
261
+ def encode_latents(self, image: torch.Tensor):
262
+ dtype = self.vae.dtype
263
+ image = image.to(device=self.device, dtype=dtype)
264
+ image_latents = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor
265
+ return image_latents
266
+
267
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
268
+ @torch.no_grad()
269
+ def decode_latents(self, latents):
270
+ dtype = next(self.vae.parameters()).dtype
271
+ latents = latents.to(dtype=dtype)
272
+ latents = 1 / self.vae.config.scaling_factor * latents
273
+ image = self.vae.decode(latents).sample
274
+ image = (image / 2 + 0.5).clamp(0, 1).cpu().float() # ((b t) c h w)
275
+ return image
276
+
277
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
278
+ def prepare_extra_step_kwargs(self, generator, eta):
279
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
280
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
281
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
282
+ # and should be between [0, 1]
283
+
284
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
285
+ extra_step_kwargs = {}
286
+ if accepts_eta:
287
+ extra_step_kwargs["eta"] = eta
288
+
289
+ # check if the scheduler accepts generator
290
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
291
+ if accepts_generator:
292
+ extra_step_kwargs["generator"] = generator
293
+ return extra_step_kwargs
294
+
295
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
296
+ def prepare_video_latents(
297
+ self,
298
+ image_latents: torch.Tensor,
299
+ num_channels_latents: int,
300
+ video_length: int = 12,
301
+ height: int = 256,
302
+ width: int = 256,
303
+ device: torch.device = torch.device("cuda"),
304
+ dtype: torch.dtype = torch.float32,
305
+ generator: Optional[torch.Generator] = None,
306
+ ):
307
+ batch_size = len(image_latents)
308
+ shape = (
309
+ batch_size,
310
+ num_channels_latents,
311
+ video_length - 1,
312
+ height // self.vae_scale_factor,
313
+ width // self.vae_scale_factor
314
+ )
315
+
316
+ image_latents = image_latents.unsqueeze(2) # (b c 1 h w)
317
+ rand_noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
318
+ noise_latents = torch.cat([image_latents, rand_noise], dim=2)
319
+
320
+ # scale the initial noise by the standard deviation required by the scheduler
321
+ noise_latents = noise_latents * self.scheduler.init_noise_sigma
322
+
323
+ return noise_latents
324
+
325
+ @torch.no_grad()
326
+ def __call__(
327
+ self,
328
+ images: List[PIL.Image.Image],
329
+ audios: Union[List[np.ndarray], List[torch.Tensor]],
330
+ text_encodings: List[torch.Tensor],
331
+ video_length: int = 12,
332
+ height: int = 256,
333
+ width: int = 256,
334
+ num_inference_steps: int = 20,
335
+ audio_guidance_scale: float = 4.0,
336
+ text_guidance_scale: float = 1.0,
337
+ generator: Optional[torch.Generator] = None,
338
+ return_dict: bool = True
339
+ ):
340
+ # 0. Default height and width to unet
341
+ device = self.device
342
+ dtype = self.dtype
343
+
344
+ batch_size = len(images)
345
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
346
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
347
+
348
+ do_text_classifier_free_guidance = (text_guidance_scale > 1.0)
349
+ do_audio_classifier_free_guidance = (audio_guidance_scale > 1.0)
350
+
351
+ # 1. Encode text into ((k b) f n d)
352
+ text_encodings = self.encode_text(
353
+ text_encodings=text_encodings,
354
+ device=device,
355
+ dtype=dtype,
356
+ do_text_classifier_free_guidance=do_text_classifier_free_guidance,
357
+ do_audio_classifier_free_guidance=do_audio_classifier_free_guidance
358
+ ) # ((k b), n, d)
359
+ text_encodings = repeat(text_encodings, "b n d -> b t n d", t=video_length).to(device=device, dtype=dtype)
360
+
361
+ # 2. Encode audio
362
+ # audio_encodings: ((k b), n, d)
363
+ # audio_masks: ((k b), s, n)
364
+ audio_encodings, audio_masks = self.encode_audio(
365
+ audios, video_length, do_text_classifier_free_guidance, do_audio_classifier_free_guidance, device, dtype
366
+ )
367
+
368
+ # 3. Prepare image latent
369
+ image = self.image_processor.preprocess(images)
370
+ image_latents = self.encode_latents(image).to(device=device, dtype=dtype) # (b c h w)
371
+
372
+ # 4. Prepare unet noising video latents
373
+ video_latents = self.prepare_video_latents(
374
+ image_latents=image_latents,
375
+ num_channels_latents=self.unet.config.in_channels,
376
+ video_length=video_length,
377
+ height=height,
378
+ width=width,
379
+ dtype=dtype,
380
+ device=device,
381
+ generator=generator,
382
+ ) # (b c f h w)
383
+
384
+ # 5. Prepare timesteps and extra step kwargs
385
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
386
+ timesteps = self.scheduler.timesteps
387
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta=0.0)
388
+
389
+ # 6. Denoising loop
390
+ for i, t in enumerate(self.progress_bar(timesteps)):
391
+ latent_model_input = [video_latents]
392
+ if do_text_classifier_free_guidance:
393
+ latent_model_input.append(video_latents)
394
+ if do_audio_classifier_free_guidance:
395
+ latent_model_input.append(video_latents)
396
+ latent_model_input = torch.cat(latent_model_input)
397
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
398
+
399
+ # predict the noise residual
400
+ noise_pred = self.unet(
401
+ latent_model_input,
402
+ t,
403
+ encoder_hidden_states=text_encodings,
404
+ audio_encoder_hidden_states=audio_encodings,
405
+ audio_attention_mask=audio_masks
406
+ ).sample
407
+
408
+ # perform guidance
409
+ if do_text_classifier_free_guidance and do_audio_classifier_free_guidance: # dual cfg
410
+ noise_pred_uncond, noise_pred_text, noise_pred_text_audio = noise_pred.chunk(3)
411
+ noise_pred = noise_pred_uncond + \
412
+ text_guidance_scale * (noise_pred_text - noise_pred_uncond) + \
413
+ audio_guidance_scale * (noise_pred_text_audio - noise_pred_text)
414
+ elif do_text_classifier_free_guidance: # only text cfg
415
+ noise_pred_audio, noise_pred_text_audio = noise_pred.chunk(2)
416
+ noise_pred = noise_pred_audio + \
417
+ text_guidance_scale * (noise_pred_text_audio - noise_pred_audio)
418
+ elif do_audio_classifier_free_guidance: # only audio cfg
419
+ noise_pred_text, noise_pred_text_audio = noise_pred.chunk(2)
420
+ noise_pred = noise_pred_text + \
421
+ audio_guidance_scale * (noise_pred_text_audio - noise_pred_text)
422
+
423
+ # The first-frame latent always serves as the unchanged conditioning frame
424
+ video_latents[:, :, 1:, :, :] = self.scheduler.step(noise_pred[:, :, 1:, :, :], t,
425
+ video_latents[:, :, 1:, :, :],
426
+ **extra_step_kwargs).prev_sample
427
+ video_latents = video_latents.contiguous()
428
+
429
+ # 7. Post-processing
430
+ video_latents = rearrange(video_latents, "b c f h w -> (b f) c h w")
431
+ videos = self.decode_latents(video_latents).detach().cpu()
432
+ videos = rearrange(videos, "(b f) c h w -> b f c h w", f=video_length) # value range [0, 1]
433
+
434
+ if not return_dict:
435
+ return videos
436
+
437
+ return {"videos": videos}
438
+
439
+
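The dual classifier-free guidance in `__call__` above stacks [unconditional, text-only, text+audio] predictions along the batch dimension and combines them with two scales. A toy sketch of just that arithmetic on random tensors; the guidance scales and latent shape below are illustrative, not the pipeline's real values:

import torch

text_guidance_scale, audio_guidance_scale = 7.5, 4.0

# Stacked predictions for one sample: [unconditional, text-only, text+audio].
noise_pred = torch.randn(3, 4, 12, 32, 32)
noise_pred_uncond, noise_pred_text, noise_pred_text_audio = noise_pred.chunk(3)

guided = noise_pred_uncond \
    + text_guidance_scale * (noise_pred_text - noise_pred_uncond) \
    + audio_guidance_scale * (noise_pred_text_audio - noise_pred_text)
print(guided.shape)  # torch.Size([1, 4, 12, 32, 32])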
440
+ def load_and_transform_images_stable_diffusion(
441
+ images: Union[List[np.ndarray], torch.Tensor, np.ndarray],
442
+ size=512,
443
+ flip=False,
444
+ randcrop=False,
445
+ normalize=True
446
+ ):
447
+ """
448
+ @images: (List of) np.uint8 images of shape (h, w, 3)
449
+ or tensor of shape (b, c, h, w) in [0., 1.0]
450
+
451
+ """
452
+
453
+ assert isinstance(images, (List, torch.Tensor, np.ndarray)), type(images)
454
+ if isinstance(images, List):
455
+ assert isinstance(images[0], np.ndarray)
456
+ assert images[0].dtype == np.uint8
457
+ assert images[0].shape[2] == 3
458
+
459
+ # convert np images into torch float tensor
460
+ images = torch.from_numpy(
461
+ rearrange(np.stack(images, axis=0), "f h w c -> f c h w")
462
+ ).float() / 255.
463
+ elif isinstance(images, np.ndarray):
464
+ assert isinstance(images, np.ndarray)
465
+ assert images.dtype == np.uint8
466
+ assert images.shape[3] == 3
467
+
468
+ # convert np images into torch float tensor
469
+ images = torch.from_numpy(
470
+ rearrange(images, "f h w c -> f c h w")
471
+ ).float() / 255.
472
+
473
+ assert images.shape[1] == 3
474
+ assert torch.all(images <= 1.0) and torch.all(images >= 0.0)
475
+
476
+ h, w = images.shape[-2:]
477
+ if isinstance(size, int):
478
+ target_h, target_w = size, size
479
+ else:
480
+ target_h, target_w = size
481
+
482
+ # first crop the image
483
+ target_aspect_ratio = float(target_h) / target_w
484
+ curr_aspect_ratio = float(h) / w
485
+ if target_aspect_ratio >= curr_aspect_ratio: # trim w
486
+ trimmed_w = int(h / target_aspect_ratio)
487
+ images = images[:, :, :, (w - trimmed_w) // 2: (w - trimmed_w) // 2 + trimmed_w]
488
+ else: # trim h
489
+ trimmed_h = int(w * target_aspect_ratio)
490
+ images = images[:, :, (h - trimmed_h) // 2: (h - trimmed_h) // 2 + trimmed_h]
491
+
492
+ transform_list = [
493
+ transforms.Resize(
494
+ size,
495
+ interpolation=transforms.InterpolationMode.BILINEAR,
496
+ antialias=True
497
+ ),
498
+ ]
499
+
500
+ # assert not randcrop
501
+ if randcrop:
502
+ transform_list.append(transforms.RandomCrop(size))
503
+ else:
504
+ transform_list.append(transforms.CenterCrop(size))
505
+
506
+ if flip:
507
+ transform_list.append(transforms.RandomHorizontalFlip(p=1.0))
508
+
509
+ if normalize:
510
+ transform_list.append(transforms.Normalize([0.5], [0.5]))
511
+
512
+ data_transform = transforms.Compose(transform_list)
513
+
514
+ images = data_transform(images)
515
+ return images
516
+
517
+
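A quick sketch of the helper above on a synthetic uint8 frame: the non-square image is center-cropped, resized, and normalized into a single (1, 3, 256, 256) tensor in roughly [-1, 1]. The random frame is only for illustration:

import numpy as np

from pipeline import load_and_transform_images_stable_diffusion

frame = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # synthetic H x W x 3 image
batch = load_and_transform_images_stable_diffusion([frame], size=256)
print(batch.shape)                                 # torch.Size([1, 3, 256, 256])
print(batch.min().item(), batch.max().item())      # roughly -1.0 .. 1.0 after Normalize([0.5], [0.5])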
518
+ def load_image(image_path):
519
+ image = PIL.Image.open(image_path).convert('RGB')
520
+
521
+ width, height = image.size
522
+ if width < height:
523
+ new_width = 256
524
+ new_height = int((256 / width) * height)
525
+ else:
526
+ new_height = 256
527
+ new_width = int((256 / height) * width)
528
+
529
+ # Rescale the image
530
+ image = image.resize((new_width, new_height), PIL.Image.LANCZOS)
531
+
532
+ # Crop a 256x256 square from the center
533
+ left = (new_width - 256) / 2
534
+ top = (new_height - 256) / 2
535
+ right = (new_width + 256) / 2
536
+ bottom = (new_height + 256) / 2
537
+ image = image.crop((left, top, right, bottom))
538
+
539
+ return image
540
+
541
+
542
+ def load_audio(audio_path):
543
+ audio, audio_sr = torchaudio.load(audio_path)
544
+ if audio.ndim == 1: audio = audio.unsqueeze(0)
545
+ else:
546
+ audio = audio.mean(dim=0).unsqueeze(0)
547
+ audio = torchaudio.functional.resample(audio, orig_freq=audio_sr, new_freq=16000)
548
+ audio = audio[:, :32000].contiguous().float()
549
+ if audio.shape[1] < 32000:
550
+ audio = torch.cat([audio, torch.ones(1, 32000-audio.shape[1]).float()], dim=1)
551
+
552
+ return audio.contiguous()
553
+
554
+
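Together, the two loaders above normalize arbitrary inputs to what the pipeline expects: a 256x256 RGB `PIL.Image` and a (1, 32000) mono waveform resampled to 16 kHz. A usage sketch; the file paths are placeholders, not files shipped in this commit:

from pipeline import load_audio, load_image

image = load_image("example_frame.png")   # placeholder path to any RGB image
audio = load_audio("example_sound.wav")   # placeholder path to any audio file

print(image.size)    # (256, 256)
print(audio.shape)   # torch.Size([1, 32000])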
555
+ @torch.no_grad()
556
+ def generate_videos(
557
+ pipeline,
558
+ image_path: str = '',
559
+ audio_path: str = '',
560
+ category_text_encoding: Optional[torch.Tensor] = None,
561
+ image_size: Tuple[int, int] = (256, 256),
562
+ video_fps: int = 6,
563
+ video_num_frame: int = 12,
564
+ audio_guidance_scale: float = 4.0,
565
+ denoising_step: int = 20,
566
+ text_guidance_scale: float = 1.0,
567
+ seed: int = 0,
568
+ save_path: str = "",
569
+ device: torch.device = torch.device("cuda"),
570
+ ):
571
+ image = load_image(image_path)
572
+ audio = load_audio(audio_path)
573
+
574
+ generator = torch.Generator(device=device)
575
+ generator.manual_seed(seed)
576
+ generated_video = pipeline(
577
+ images=[image],
578
+ audios=[audio],
579
+ text_encodings=[category_text_encoding],
580
+ video_length=video_num_frame,
581
+ height=image_size[0],
582
+ width=image_size[1],
583
+ num_inference_steps=denoising_step,
584
+ audio_guidance_scale=audio_guidance_scale,
585
+ text_guidance_scale=text_guidance_scale,
586
+ generator=generator,
587
+ return_dict=False
588
+ )[0] # (f c h w) in range [0, 1]
589
+ generated_video = (generated_video.permute(0, 2, 3, 1).contiguous() * 255).byte()
590
+
591
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
592
+ torchvision.io.write_video(
593
+ filename=save_path,
594
+ video_array=generated_video,
595
+ fps=video_fps,
596
+ audio_array=audio,
597
+ audio_fps=16000,
598
+ audio_codec="aac"
599
+ )
600
+
601
+ return
602
+
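A minimal sketch of driving `generate_videos` once an `AudioCondAnimationPipeline` has been assembled from the checkpoints in this repo and moved to GPU (as app.py does). The input paths and the category text-encoding file are placeholders, so this is a usage outline rather than a runnable demo on its own:

import torch

from pipeline import AudioCondAnimationPipeline, generate_videos


def run_demo(pipeline: AudioCondAnimationPipeline) -> None:
    # `pipeline` is assumed to already be built and on CUDA; all paths below are placeholders.
    category_text_encoding = torch.load("category_encoding.pt")  # placeholder (1, 77, 768) CLIP text encoding
    generate_videos(
        pipeline,
        image_path="input_frame.png",
        audio_path="input_sound.wav",
        category_text_encoding=category_text_encoding,
        image_size=(256, 256),
        video_fps=6,
        video_num_frame=12,
        audio_guidance_scale=4.0,
        text_guidance_scale=1.0,
        denoising_step=20,
        seed=0,
        save_path="outputs/result.mp4",
        device=torch.device("cuda"),
    )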
pretrained/openai-clip-l_null_text_encoding.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06170f5fa389ab44a9e12c27146a2b6569cdea6808a58ba341ce50903939da98
3
+ size 237430
pretrained/stable-diffusion-v1-5/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "_class_name": "PNDMScheduler",
3
+ "_diffusers_version": "0.6.0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "num_train_timesteps": 1000,
8
+ "set_alpha_to_one": false,
9
+ "skip_prk_steps": true,
10
+ "steps_offset": 1,
11
+ "trained_betas": null,
12
+ "clip_sample": false
13
+ }
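The scheduler config above is a standard diffusers config, so it can be loaded directly; the local path matches the directory added in this commit:

from diffusers.schedulers import PNDMScheduler

scheduler = PNDMScheduler.from_pretrained("pretrained/stable-diffusion-v1-5/scheduler")
print(scheduler.config.num_train_timesteps)  # 1000
print(scheduler.config.beta_schedule)        # "scaled_linear"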
pretrained/stable-diffusion-v1-5/vae/config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.6.0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "in_channels": 3,
18
+ "latent_channels": 4,
19
+ "layers_per_block": 2,
20
+ "norm_num_groups": 32,
21
+ "out_channels": 3,
22
+ "sample_size": 512,
23
+ "up_block_types": [
24
+ "UpDecoderBlock2D",
25
+ "UpDecoderBlock2D",
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D"
28
+ ]
29
+ }
pretrained/stable-diffusion-v1-5/vae/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b134cded8eb78b184aefb8805b6b572f36fa77b255c483665dda931fa0130c5
3
+ size 334707217
pretrained/stable-diffusion-v1-5/vae/diffusion_pytorch_model.fp16.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7643b3e40b9f128eda5fe174fea73c3ef3903562651fb344a79439709c2e503
3
+ size 167405651
pretrained/stable-diffusion-v1-5/vae/diffusion_pytorch_model.fp16.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fbcf0ebe55a0984f5a5e00d8c4521d52359af7229bb4d81890039d2aa16dd7c
3
+ size 167335342
pretrained/stable-diffusion-v1-5/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2b5134f4dbc140d9c11f11cba3233099e00af40f262f136c691fb7d38d2194c
3
+ size 334643276
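With the four `block_out_channels` in the VAE config above, the pipeline's `vae_scale_factor` works out to 2^(4-1) = 8, so 256x256 frames are encoded into 32x32 latents. A sketch of loading the VAE from this folder and checking that:

from diffusers.models import AutoencoderKL

vae = AutoencoderKL.from_pretrained("pretrained/stable-diffusion-v1-5/vae")
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
print(vae_scale_factor)         # 8
print(256 // vae_scale_factor)  # 32 -> spatial size of the video latents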
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ accelerate==0.32.1
2
+ diffusers==0.29.2
3
+ einops==0.8.0
4
+ ftfy==6.2.0
5
+ imageio==2.34.2
6
+ iopath==0.1.10
7
+ pytorchvideo==0.1.5
8
+ timm==1.0.7
9
+ tqdm==4.66.4
10
+ transformers==4.42.4
11
+ wandb==0.17.5
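A small sanity check that the pinned core libraries are installed at the versions listed above, using the standard library's importlib.metadata; the subset checked here is just the packages the pipeline code imports directly:

from importlib.metadata import version

# Versions pinned in requirements.txt above.
assert version("diffusers") == "0.29.2"
assert version("transformers") == "4.42.4"
assert version("einops") == "0.8.0"
print("pinned versions OK")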
unet.py ADDED
@@ -0,0 +1,839 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+ import json
18
+ from einops import repeat
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import torch.utils.checkpoint
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.loaders import UNet2DConditionLoadersMixin
27
+ from diffusers.utils import BaseOutput, logging
28
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
29
+ from diffusers.models.embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
30
+ from diffusers.models.modeling_utils import ModelMixin
31
+
32
+ from unet_blocks import (
33
+ all_modules,
34
+ get_down_block,
35
+ get_up_block,
36
+ get_mid_block,
37
+ )
38
+
39
+ from unet_utils import FFInflatedConv3d
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+
44
+ @dataclass
45
+ class UNet3DConditionOutput(BaseOutput):
46
+ """
47
+ Args:
48
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
49
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
50
+ """
51
+
52
+ sample: torch.FloatTensor
53
+
54
+
55
+ class AudioUNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
56
+ r"""
57
+ AudioUNet3DConditionModel is a conditional 3D UNet that takes a noisy video sample, conditioning states (text and audio), and a timestep,
58
+ and returns a sample-shaped output.
59
+
60
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
61
+ implements for all the models (such as downloading or saving, etc.)
62
+
63
+ Parameters:
64
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
65
+ Height and width of input/output sample.
66
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
67
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
68
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
69
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
70
+ Whether to flip the sin to cos in the time embedding.
71
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
72
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
73
+ The tuple of downsample blocks to use.
74
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
75
+ The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
76
+ mid block layer if `None`.
77
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
78
+ The tuple of upsample blocks to use.
79
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
80
+ Whether to include self-attention in the basic transformer blocks, see
81
+ [`~models.attention.BasicTransformerBlock`].
82
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
83
+ The tuple of output channels for each block.
84
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
85
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
86
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
87
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
88
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
89
+ If `None`, it will skip the normalization and activation layers in post-processing
90
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
91
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
92
+ The dimension of the cross attention features.
93
+ encoder_hid_dim (`int`, *optional*, defaults to None):
94
+ If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`.
95
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
96
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
97
+ for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
98
+ class_embed_type (`str`, *optional*, defaults to None):
99
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
100
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
101
+ addition_embed_type (`str`, *optional*, defaults to None):
102
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
103
+ "text". "text" will use the `TextTimeEmbedding` layer.
104
+ num_class_embeds (`int`, *optional*, defaults to None):
105
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
106
+ class conditioning with `class_embed_type` equal to `None`.
107
+ time_embedding_type (`str`, *optional*, default to `positional`):
108
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
109
+ time_embedding_dim (`int`, *optional*, default to `None`):
110
+ An optional override for the dimension of the projected time embedding.
111
+ time_embedding_act_fn (`str`, *optional*, default to `None`):
112
+ Optional activation function to use on the time embeddings only one time before they as passed to the rest
113
+ of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
114
+ timestep_post_act (`str, *optional*, default to `None`):
115
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
116
+ time_cond_proj_dim (`int`, *optional*, default to `None`):
117
+ The dimension of `cond_proj` layer in timestep embedding.
118
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
119
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
120
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
121
+ using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
122
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
123
+ embeddings with the class embeddings.
124
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
125
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
126
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
127
+ `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
128
+ default to `False`.
129
+ """
130
+
131
+ _supports_gradient_checkpointing = True
132
+
133
+ @register_to_config
134
+ def __init__(
135
+ self,
136
+ sample_size: Optional[int] = None,
137
+ in_channels: int = 4,
138
+ out_channels: int = 4,
139
+ center_input_sample: bool = False,
140
+ flip_sin_to_cos: bool = True,
141
+ freq_shift: int = 0,
142
+ down_block_types: Tuple[str] = (
143
+ "FFSpatioAudioTempCrossAttnDownBlock3D",
144
+ "FFSpatioAudioTempCrossAttnDownBlock3D",
145
+ "FFSpatioAudioTempCrossAttnDownBlock3D",
146
+ "FFSpatioTempResDownBlock3D",
147
+ ),
148
+ mid_block_type: Optional[str] = "FFSpatioAudioTempCrossAttnUNetMidBlock3D",
149
+ up_block_types: Tuple[str] = (
150
+ "FFSpatioTempResUpBlock3D",
151
+ "FFSpatioAudioTempCrossAttnUpBlock3D",
152
+ "FFSpatioAudioTempCrossAttnUpBlock3D",
153
+ "FFSpatioAudioTempCrossAttnUpBlock3D"
154
+ ),
155
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
156
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
157
+ layers_per_block: Union[int, Tuple[int]] = 2,
158
+ downsample_padding: int = 1,
159
+ mid_block_scale_factor: float = 1,
160
+ act_fn: str = "silu",
161
+ norm_num_groups: Optional[int] = 32,
162
+ norm_eps: float = 1e-5,
163
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
164
+ encoder_hid_dim: Optional[int] = None,
165
+ attention_head_dim: Union[int, Tuple[int]] = 8,
166
+ dual_cross_attention: bool = False,
167
+ use_linear_projection: bool = False,
168
+ class_embed_type: Optional[str] = None,
169
+ addition_embed_type: Optional[str] = None,
170
+ num_class_embeds: Optional[int] = None,
171
+ upcast_attention: bool = False,
172
+ resnet_time_scale_shift: str = "default",
173
+ resnet_skip_time_act: bool = False,
174
+ resnet_out_scale_factor: int = 1.0,
175
+ time_embedding_type: str = "positional",
176
+ time_embedding_dim: Optional[int] = None,
177
+ time_embedding_act_fn: Optional[str] = None,
178
+ timestep_post_act: Optional[str] = None,
179
+ time_cond_proj_dim: Optional[int] = None,
180
+ conv_in_kernel: int = 3,
181
+ conv_out_kernel: int = 3,
182
+ projection_class_embeddings_input_dim: Optional[int] = None,
183
+ class_embeddings_concat: bool = False,
184
+ mid_block_only_cross_attention: Optional[bool] = None,
185
+ cross_attention_norm: Optional[str] = None,
186
+ addition_embed_type_num_heads=64,
187
+ audio_cross_attention_dim: int = 768,
188
+ ):
189
+ super().__init__()
190
+
191
+ self.sample_size = sample_size
192
+
193
+ # Check inputs
194
+ if len(down_block_types) != len(up_block_types):
195
+ raise ValueError(
196
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
197
+ )
198
+
199
+ if len(block_out_channels) != len(down_block_types):
200
+ raise ValueError(
201
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
202
+ )
203
+
204
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
205
+ raise ValueError(
206
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
207
+ )
208
+
209
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
210
+ raise ValueError(
211
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
212
+ )
213
+
214
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
215
+ raise ValueError(
216
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
217
+ )
218
+
219
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
220
+ raise ValueError(
221
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
222
+ )
223
+
224
+ # input
225
+ conv_in_padding = (conv_in_kernel - 1) // 2
226
+ self.conv_in = FFInflatedConv3d(
227
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
228
+ )
229
+
230
+ # time
231
+ if time_embedding_type == "fourier":
232
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
233
+ if time_embed_dim % 2 != 0:
234
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
235
+ self.time_proj = GaussianFourierProjection(
236
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
237
+ )
238
+ timestep_input_dim = time_embed_dim
239
+ elif time_embedding_type == "positional":
240
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
241
+
242
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
243
+ timestep_input_dim = block_out_channels[0]
244
+ else:
245
+ raise ValueError(
246
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
247
+ )
248
+
249
+ self.time_embedding = TimestepEmbedding(
250
+ timestep_input_dim,
251
+ time_embed_dim,
252
+ act_fn=act_fn,
253
+ post_act_fn=timestep_post_act,
254
+ cond_proj_dim=time_cond_proj_dim,
255
+ )
256
+
257
+ if encoder_hid_dim is not None:
258
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
259
+ else:
260
+ self.encoder_hid_proj = None
261
+
262
+ # class embedding
263
+ if class_embed_type is None and num_class_embeds is not None:
264
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
265
+ elif class_embed_type == "timestep":
266
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
267
+ elif class_embed_type == "identity":
268
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
269
+ elif class_embed_type == "projection":
270
+ if projection_class_embeddings_input_dim is None:
271
+ raise ValueError(
272
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
273
+ )
274
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
275
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
276
+ # 2. it projects from an arbitrary input dimension.
277
+ #
278
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
279
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
280
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
281
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
282
+ elif class_embed_type == "simple_projection":
283
+ if projection_class_embeddings_input_dim is None:
284
+ raise ValueError(
285
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
286
+ )
287
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
288
+ else:
289
+ self.class_embedding = None
290
+
291
+ if addition_embed_type == "text":
292
+ if encoder_hid_dim is not None:
293
+ text_time_embedding_from_dim = encoder_hid_dim
294
+ else:
295
+ text_time_embedding_from_dim = cross_attention_dim
296
+
297
+ self.add_embedding = TextTimeEmbedding(
298
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
299
+ )
300
+ elif addition_embed_type is not None:
301
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.")
302
+
303
+ if time_embedding_act_fn is None:
304
+ self.time_embed_act = None
305
+ elif time_embedding_act_fn == "swish":
306
+ self.time_embed_act = lambda x: F.silu(x)
307
+ elif time_embedding_act_fn == "mish":
308
+ self.time_embed_act = nn.Mish()
309
+ elif time_embedding_act_fn == "silu":
310
+ self.time_embed_act = nn.SiLU()
311
+ elif time_embedding_act_fn == "gelu":
312
+ self.time_embed_act = nn.GELU()
313
+ else:
314
+ raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
315
+
316
+ self.down_blocks = nn.ModuleList([])
317
+ self.up_blocks = nn.ModuleList([])
318
+
319
+ if isinstance(only_cross_attention, bool):
320
+ if mid_block_only_cross_attention is None:
321
+ mid_block_only_cross_attention = only_cross_attention
322
+
323
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
324
+
325
+ if mid_block_only_cross_attention is None:
326
+ mid_block_only_cross_attention = False
327
+
328
+ if isinstance(attention_head_dim, int):
329
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
330
+
331
+ if isinstance(cross_attention_dim, int):
332
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
333
+
334
+ if isinstance(layers_per_block, int):
335
+ layers_per_block = [layers_per_block] * len(down_block_types)
336
+
337
+ if class_embeddings_concat:
338
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
339
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
340
+ # regular time embeddings
341
+ blocks_time_embed_dim = time_embed_dim * 2
342
+ else:
343
+ blocks_time_embed_dim = time_embed_dim
344
+
345
+ # down
346
+ output_channel = block_out_channels[0]
347
+ for i, down_block_type in enumerate(down_block_types):
348
+ input_channel = output_channel
349
+ output_channel = block_out_channels[i]
350
+ is_final_block = i == len(block_out_channels) - 1
351
+
352
+ down_block = get_down_block(
353
+ down_block_type,
354
+ num_layers=layers_per_block[i],
355
+ in_channels=input_channel,
356
+ out_channels=output_channel,
357
+ temb_channels=blocks_time_embed_dim,
358
+ add_downsample=not is_final_block,
359
+ resnet_eps=norm_eps,
360
+ resnet_act_fn=act_fn,
361
+ resnet_groups=norm_num_groups,
362
+ cross_attention_dim=cross_attention_dim[i],
363
+ attn_num_head_channels=attention_head_dim[i],
364
+ downsample_padding=downsample_padding,
365
+ dual_cross_attention=dual_cross_attention,
366
+ use_linear_projection=use_linear_projection,
367
+ only_cross_attention=only_cross_attention[i],
368
+ upcast_attention=upcast_attention,
369
+ resnet_time_scale_shift=resnet_time_scale_shift,
370
+ audio_cross_attention_dim=audio_cross_attention_dim
371
+ )
372
+ self.down_blocks.append(down_block)
373
+
374
+ # mid
375
+ if mid_block_type is None:
376
+ self.mid_block = None
377
+ else:
378
+ self.mid_block = get_mid_block(
379
+ mid_block_type=mid_block_type,
380
+ in_channels=block_out_channels[-1],
381
+ temb_channels=blocks_time_embed_dim,
382
+ resnet_eps=norm_eps,
383
+ resnet_act_fn=act_fn,
384
+ output_scale_factor=mid_block_scale_factor,
385
+ resnet_time_scale_shift=resnet_time_scale_shift,
386
+ cross_attention_dim=cross_attention_dim[-1],
387
+ attn_num_head_channels=attention_head_dim[-1],
388
+ resnet_groups=norm_num_groups,
389
+ dual_cross_attention=dual_cross_attention,
390
+ use_linear_projection=use_linear_projection,
391
+ upcast_attention=upcast_attention,
392
+ audio_cross_attention_dim=audio_cross_attention_dim
393
+ )
394
+
395
+ # count how many layers upsample the images
396
+ self.num_upsamplers = 0
397
+
398
+ # up
399
+ reversed_block_out_channels = list(reversed(block_out_channels))
400
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
401
+ reversed_layers_per_block = list(reversed(layers_per_block))
402
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
403
+ only_cross_attention = list(reversed(only_cross_attention))
404
+
405
+ output_channel = reversed_block_out_channels[0]
406
+ for i, up_block_type in enumerate(up_block_types):
407
+ is_final_block = i == len(block_out_channels) - 1
408
+
409
+ prev_output_channel = output_channel
410
+ output_channel = reversed_block_out_channels[i]
411
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
412
+
413
+ # add upsample block for all BUT final layer
414
+ if not is_final_block:
415
+ add_upsample = True
416
+ self.num_upsamplers += 1
417
+ else:
418
+ add_upsample = False
419
+
420
+ up_block = get_up_block(
421
+ up_block_type,
422
+ num_layers=reversed_layers_per_block[i] + 1,
423
+ in_channels=input_channel,
424
+ out_channels=output_channel,
425
+ prev_output_channel=prev_output_channel,
426
+ temb_channels=blocks_time_embed_dim,
427
+ add_upsample=add_upsample,
428
+ resnet_eps=norm_eps,
429
+ resnet_act_fn=act_fn,
430
+ resnet_groups=norm_num_groups,
431
+ cross_attention_dim=reversed_cross_attention_dim[i],
432
+ attn_num_head_channels=reversed_attention_head_dim[i],
433
+ dual_cross_attention=dual_cross_attention,
434
+ use_linear_projection=use_linear_projection,
435
+ only_cross_attention=only_cross_attention[i],
436
+ upcast_attention=upcast_attention,
437
+ resnet_time_scale_shift=resnet_time_scale_shift,
438
+ audio_cross_attention_dim=audio_cross_attention_dim
439
+ )
440
+ self.up_blocks.append(up_block)
441
+
442
+ # out
443
+ if norm_num_groups is not None:
444
+ self.conv_norm_out = nn.GroupNorm(
445
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
446
+ )
447
+
448
+ if act_fn == "swish":
449
+ self.conv_act = lambda x: F.silu(x)
450
+ elif act_fn == "mish":
451
+ self.conv_act = nn.Mish()
452
+ elif act_fn == "silu":
453
+ self.conv_act = nn.SiLU()
454
+ elif act_fn == "gelu":
455
+ self.conv_act = nn.GELU()
456
+ else:
457
+ raise ValueError(f"Unsupported activation function: {act_fn}")
458
+
459
+ else:
460
+ self.conv_norm_out = None
461
+ self.conv_act = None
462
+
463
+ conv_out_padding = (conv_out_kernel - 1) // 2
464
+ self.conv_out = FFInflatedConv3d(
465
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
466
+ )
467
+
468
+ @property
469
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
470
+ r"""
471
+ Returns:
472
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
473
+ indexed by its weight name.
474
+ """
475
+ # set recursively
476
+ processors = {}
477
+
478
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
479
+ if hasattr(module, "set_processor"):
480
+ processors[f"{name}.processor"] = module.processor
481
+
482
+ for sub_name, child in module.named_children():
483
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
484
+
485
+ return processors
486
+
487
+ for name, module in self.named_children():
488
+ fn_recursive_add_processors(name, module, processors)
489
+
490
+ return processors
491
+
492
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
493
+ r"""
494
+ Parameters:
495
+ `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
496
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
497
+ of **all** `Attention` layers.
498
+ In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
499
+
500
+ """
501
+ count = len(self.attn_processors.keys())
502
+
503
+ if isinstance(processor, dict) and len(processor) != count:
504
+ raise ValueError(
505
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
506
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
507
+ )
508
+
509
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
510
+ if hasattr(module, "set_processor"):
511
+ if not isinstance(processor, dict):
512
+ module.set_processor(processor)
513
+ else:
514
+ module.set_processor(processor.pop(f"{name}.processor"))
515
+
516
+ for sub_name, child in module.named_children():
517
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
518
+
519
+ for name, module in self.named_children():
520
+ fn_recursive_attn_processor(name, module, processor)
521
+
522
+ def set_default_attn_processor(self):
523
+ """
524
+ Disables custom attention processors and sets the default attention implementation.
525
+ """
526
+ self.set_attn_processor(AttnProcessor())
527
+
528
+ def set_attention_slice(self, slice_size):
529
+ r"""
530
+ Enable sliced attention computation.
531
+
532
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
533
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
534
+
535
+ Args:
536
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
537
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
538
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
539
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
540
+ must be a multiple of `slice_size`.
541
+ """
542
+ sliceable_head_dims = []
543
+
544
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
545
+ if hasattr(module, "set_attention_slice"):
546
+ sliceable_head_dims.append(module.sliceable_head_dim)
547
+
548
+ for child in module.children():
549
+ fn_recursive_retrieve_sliceable_dims(child)
550
+
551
+ # retrieve number of attention layers
552
+ for module in self.children():
553
+ fn_recursive_retrieve_sliceable_dims(module)
554
+
555
+ num_sliceable_layers = len(sliceable_head_dims)
556
+
557
+ if slice_size == "auto":
558
+ # half the attention head size is usually a good trade-off between
559
+ # speed and memory
560
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
561
+ elif slice_size == "max":
562
+ # make smallest slice possible
563
+ slice_size = num_sliceable_layers * [1]
564
+
565
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
566
+
567
+ if len(slice_size) != len(sliceable_head_dims):
568
+ raise ValueError(
569
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
570
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
571
+ )
572
+
573
+ for i in range(len(slice_size)):
574
+ size = slice_size[i]
575
+ dim = sliceable_head_dims[i]
576
+ if size is not None and size > dim:
577
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
578
+
579
+ # Recursively walk through all the children.
580
+ # Any children which exposes the set_attention_slice method
581
+ # gets the message
582
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
583
+ if hasattr(module, "set_attention_slice"):
584
+ module.set_attention_slice(slice_size.pop())
585
+
586
+ for child in module.children():
587
+ fn_recursive_set_attention_slice(child, slice_size)
588
+
589
+ reversed_slice_size = list(reversed(slice_size))
590
+ for module in self.children():
591
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
592
+
593
+ def _set_gradient_checkpointing(self, module, value=False):
594
+ if isinstance(module, tuple(all_modules)):
595
+ module.gradient_checkpointing = value
596
+
597
+ def forward(
598
+ self,
599
+ sample: torch.FloatTensor,
600
+ timestep: Union[torch.Tensor, float, int],
601
+ encoder_hidden_states: torch.Tensor,
602
+ audio_encoder_hidden_states: Optional[torch.Tensor] = None,
603
+ class_labels: Optional[torch.Tensor] = None,
604
+ timestep_cond: Optional[torch.Tensor] = None,
605
+ attention_mask: Optional[torch.Tensor] = None,
606
+ audio_attention_mask: Optional[torch.Tensor] = None,
607
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
608
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
609
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
610
+ return_dict: bool = True,
611
+ ) -> Union[UNet3DConditionOutput, Tuple]:
612
+ r"""
613
+ Args:
614
+ sample (`torch.FloatTensor`): (batch, channel, frame, height, width) noisy inputs tensor
615
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
616
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
617
+ return_dict (`bool`, *optional*, defaults to `True`):
618
+ Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
619
+ cross_attention_kwargs (`dict`, *optional*):
620
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
621
+ `self.processor` in
622
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
623
+
624
+ Returns:
625
+ [`UNet3DConditionOutput`] or `tuple`:
626
+ [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
627
+ returning a tuple, the first element is the sample tensor.
628
+ """
629
+ assert sample.ndim == 5, sample.size()
630
+ video_length = sample.shape[2]
631
+
632
+ # By default samples have to be at least a multiple of the overall upsampling factor.
633
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
634
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
635
+ # on the fly if necessary.
636
+ default_overall_up_factor = 2 ** self.num_upsamplers
637
+
638
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
639
+ forward_upsample_size = False
640
+ upsample_size = None
641
+
642
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
643
+ logger.info("Forward upsample size to force interpolation output size.")
644
+ forward_upsample_size = True
645
+
646
+ # prepare attention_mask
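+ # (a binary mask with 1 = keep / 0 = drop is converted into an additive bias: 0 for kept
+ # positions, -10000 for masked ones, with an extra broadcastable head dimension)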
647
+ if attention_mask is not None:
648
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
649
+ attention_mask = attention_mask.unsqueeze(1)
650
+
651
+ # 0. center input if necessary
652
+ if self.config.center_input_sample:
653
+ sample = 2 * sample - 1.0
654
+
655
+ # 1. time
656
+ timesteps = timestep
657
+ if not torch.is_tensor(timesteps):
658
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
659
+ # This would be a good case for the `match` statement (Python 3.10+)
660
+ is_mps = sample.device.type == "mps"
661
+ if isinstance(timestep, float):
662
+ dtype = torch.float32 if is_mps else torch.float64
663
+ else:
664
+ dtype = torch.int32 if is_mps else torch.int64
665
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
666
+ elif len(timesteps.shape) == 0:
667
+ timesteps = timesteps[None].to(sample.device)
668
+
669
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
670
+ timesteps = timesteps.expand(sample.shape[0])
671
+
672
+ t_emb = self.time_proj(timesteps)
673
+
674
+ # `Timesteps` does not contain any weights and will always return f32 tensors
675
+ # but time_embedding might actually be running in fp16. so we need to cast here.
676
+ # there might be better ways to encapsulate this.
677
+ t_emb = t_emb.to(dtype=self.dtype)
678
+
679
+ emb = self.time_embedding(t_emb, timestep_cond)
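+ # broadcast the time embedding across frames so that every frame receives its own copy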
680
+ emb = repeat(emb, "b c -> b f c", f=video_length)
681
+
682
+ if self.class_embedding is not None:
683
+ if class_labels is None:
684
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
685
+
686
+ if self.config.class_embed_type == "timestep":
687
+ class_labels = self.time_proj(class_labels)
688
+
689
+ # `Timesteps` does not contain any weights and will always return f32 tensors
690
+ # there might be better ways to encapsulate this.
691
+ class_labels = class_labels.to(dtype=sample.dtype)
692
+
693
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
694
+
695
+ if self.config.class_embeddings_concat:
696
+ emb = torch.cat([emb, class_emb], dim=-1)
697
+ else:
698
+ emb = emb + class_emb
699
+
700
+ if self.config.addition_embed_type == "text":
701
+ aug_emb = self.add_embedding(encoder_hidden_states)
702
+ emb = emb + aug_emb
703
+
704
+ if self.time_embed_act is not None:
705
+ emb = self.time_embed_act(emb)
706
+
707
+ if self.encoder_hid_proj is not None:
708
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
709
+
710
+ # 2. pre-process
711
+ sample = self.conv_in(sample)
712
+
713
+ # 3. down
714
+ down_block_res_samples = (sample,)
715
+ for downsample_block in self.down_blocks:
716
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
717
+ sample, res_samples = downsample_block(
718
+ hidden_states=sample,
719
+ temb=emb,
720
+ encoder_hidden_states=encoder_hidden_states,
721
+ audio_encoder_hidden_states=audio_encoder_hidden_states,
722
+ attention_mask=attention_mask,
723
+ audio_attention_mask=audio_attention_mask,
724
+ cross_attention_kwargs=cross_attention_kwargs,
725
+ )
726
+ else:
727
+ sample, res_samples = downsample_block(
728
+ hidden_states=sample, temb=emb
729
+ )
730
+
731
+ down_block_res_samples += res_samples
732
+
733
+ if down_block_additional_residuals is not None:
734
+ new_down_block_res_samples = ()
735
+
736
+ for down_block_res_sample, down_block_additional_residual in zip(
737
+ down_block_res_samples, down_block_additional_residuals
738
+ ):
739
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
740
+ new_down_block_res_samples += (down_block_res_sample,)
741
+
742
+ down_block_res_samples = new_down_block_res_samples
743
+
744
+ # 4. mid
745
+ if self.mid_block is not None:
746
+ sample = self.mid_block(
747
+ sample,
748
+ emb,
749
+ encoder_hidden_states=encoder_hidden_states,
750
+ audio_encoder_hidden_states=audio_encoder_hidden_states,
751
+ attention_mask=attention_mask,
752
+ audio_attention_mask=audio_attention_mask,
753
+ cross_attention_kwargs=cross_attention_kwargs,
754
+ )
755
+
756
+ if mid_block_additional_residual is not None:
757
+ sample = sample + mid_block_additional_residual
758
+
759
+ # 5. up
760
+ for i, upsample_block in enumerate(self.up_blocks):
761
+ is_final_block = i == len(self.up_blocks) - 1
762
+
763
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
764
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
765
+
766
+ # if we have not reached the final block and need to forward the
767
+ # upsample size, we do it here
768
+ if not is_final_block and forward_upsample_size:
769
+ upsample_size = down_block_res_samples[-1].shape[2:]
770
+
771
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
772
+ sample = upsample_block(
773
+ hidden_states=sample,
774
+ temb=emb,
775
+ res_hidden_states_tuple=res_samples,
776
+ encoder_hidden_states=encoder_hidden_states,
777
+ audio_encoder_hidden_states=audio_encoder_hidden_states,
778
+ cross_attention_kwargs=cross_attention_kwargs,
779
+ upsample_size=upsample_size,
780
+ attention_mask=attention_mask,
781
+ audio_attention_mask=audio_attention_mask,
782
+ )
783
+ else:
784
+ sample = upsample_block(
785
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
786
+ )
787
+
788
+ # 6. post-process
789
+ if self.conv_norm_out:
790
+ sample = self.conv_norm_out(sample)
791
+ sample = self.conv_act(sample)
792
+ sample = self.conv_out(sample)
793
+
794
+ if not return_dict:
795
+ return (sample,)
796
+
797
+ return UNet3DConditionOutput(sample=sample)
798
+
799
+ @classmethod
800
+ def from_pretrained_2d(cls, config3d, pretrained_model_path, subfolder=None):
801
+ # 1. Build 3D config from pretrained 2D config
802
+ if subfolder is not None:
803
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
804
+ config2d_file = os.path.join(pretrained_model_path, 'config.json')
805
+ assert os.path.isfile(config2d_file), f"{config2d_file} does not exist"
806
+
807
+ with open(config2d_file, "r") as f:
808
+ config2d = json.load(f)
809
+ config2d["_class_name"] = cls.__name__
810
+ config2d["down_block_types"] = tuple(config3d["down_block_types"])
811
+ config2d["up_block_types"] = tuple(config3d["up_block_types"])
812
+ config2d["mid_block_type"] = config3d["mid_block_type"]
813
+ if "cross_attention_dim" in config3d: config2d["cross_attention_dim"] = config3d["cross_attention_dim"]
814
+ if "audio_cross_attention_dim" in config3d: config2d["audio_cross_attention_dim"] = config3d[
815
+ "audio_cross_attention_dim"]
816
+
817
+ # 2. Build 3D model from updated 3D config
818
+ model = cls.from_config(config2d)
819
+
820
+ # 3. Load in weights from pretrained 2D nets
821
+ from diffusers.utils import WEIGHTS_NAME
822
+ model2d_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
823
+ assert os.path.isfile(model2d_file), f"{model2d_file} does not exist"
824
+ pretrained_2d_state_dict = torch.load(model2d_file, map_location="cpu")
825
+
826
+ # Add new 3D weights into pretrained 2d state_dict, to be compatible with 3D model
827
+ for k, v in model.state_dict().items():
828
+ # all '_temp' temporal weights are initialized by pretrained 2D models
829
+ if '_temp' in k:
830
+ pretrained_2d_state_dict.update({k: v})
831
+ # add new weights into pretrained 2D state_dict
832
+ elif k not in pretrained_2d_state_dict:
833
+ pretrained_2d_state_dict.update({k: v})
834
+ # if a weight has a different shape than in the 2D checkpoint, use the 3D model's weight instead
835
+ elif pretrained_2d_state_dict[k].shape != v.shape:
836
+ pretrained_2d_state_dict.update({k: v})
837
+ model.load_state_dict(pretrained_2d_state_dict)
838
+
839
+ return model
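+ 
+ # Illustrative usage sketch (class name and checkpoint path are assumptions, not part of this file):
+ #   config3d = {"down_block_types": [...], "up_block_types": [...],
+ #               "mid_block_type": "FFSpatioAudioTempCrossAttnUNetMidBlock3D",
+ #               "audio_cross_attention_dim": 768}
+ #   unet = AudioUNet3DConditionModel.from_pretrained_2d(config3d, "path/to/stable-diffusion-v1-5", subfolder="unet")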
unet_blocks.py ADDED
@@ -0,0 +1,1084 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from ff_spatio_temp_resnet_3d import (
5
+ FFSpatioTempResnetBlock3D, FFSpatioTempResDownsample3D, FFSpatioTempResUpsample3D
6
+ )
7
+ from ff_spatio_temp_transformer_3d import FFSpatioTempTransformer3DModel
8
+ from ff_spatio_audio_temp_transformer_3d import FFSpatioAudioTempTransformer3DModel
9
+
10
+
11
+ def create_custom_forward(module, return_dict=None):
12
+ def custom_forward(*inputs):
13
+ if return_dict is not None:
14
+ return module(*inputs, return_dict=return_dict)
15
+ else:
16
+ return module(*inputs)
17
+
18
+ return custom_forward
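+ 
+ # note: torch.utils.checkpoint.checkpoint re-runs a plain callable on tensor inputs, so this wrapper
+ # adapts a module's forward (optionally forcing return_dict) for gradient checkpointing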
19
+
20
+
21
+ def get_down_block(
22
+ down_block_type,
23
+ num_layers,
24
+ in_channels,
25
+ out_channels,
26
+ temb_channels,
27
+ add_downsample,
28
+ resnet_eps,
29
+ resnet_act_fn,
30
+ attn_num_head_channels,
31
+ resnet_groups=None,
32
+ cross_attention_dim=None,
33
+ downsample_padding=None,
34
+ dual_cross_attention=False,
35
+ use_linear_projection=False,
36
+ only_cross_attention=False,
37
+ upcast_attention=False,
38
+ resnet_time_scale_shift="default",
39
+ audio_cross_attention_dim=None
40
+ ):
41
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
42
+ if down_block_type == "FFSpatioTempResDownBlock3D":
43
+ return FFSpatioTempResDownBlock3D(
44
+ num_layers=num_layers,
45
+ in_channels=in_channels,
46
+ out_channels=out_channels,
47
+ temb_channels=temb_channels,
48
+ add_downsample=add_downsample,
49
+ resnet_eps=resnet_eps,
50
+ resnet_act_fn=resnet_act_fn,
51
+ resnet_groups=resnet_groups,
52
+ downsample_padding=downsample_padding,
53
+ resnet_time_scale_shift=resnet_time_scale_shift
54
+ )
55
+ elif down_block_type == "FFSpatioTempCrossAttnDownBlock3D":
56
+ if cross_attention_dim is None:
57
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
58
+ return FFSpatioTempCrossAttnDownBlock3D(
59
+ num_layers=num_layers,
60
+ in_channels=in_channels,
61
+ out_channels=out_channels,
62
+ temb_channels=temb_channels,
63
+ add_downsample=add_downsample,
64
+ resnet_eps=resnet_eps,
65
+ resnet_act_fn=resnet_act_fn,
66
+ resnet_groups=resnet_groups,
67
+ downsample_padding=downsample_padding,
68
+ cross_attention_dim=cross_attention_dim,
69
+ attn_num_head_channels=attn_num_head_channels,
70
+ dual_cross_attention=dual_cross_attention,
71
+ use_linear_projection=use_linear_projection,
72
+ only_cross_attention=only_cross_attention,
73
+ upcast_attention=upcast_attention,
74
+ resnet_time_scale_shift=resnet_time_scale_shift
75
+ )
76
+ elif down_block_type == "FFSpatioAudioTempCrossAttnDownBlock3D":
77
+ if cross_attention_dim is None:
78
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
79
+ return FFSpatioAudioTempCrossAttnDownBlock3D(
80
+ num_layers=num_layers,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ temb_channels=temb_channels,
84
+ add_downsample=add_downsample,
85
+ resnet_eps=resnet_eps,
86
+ resnet_act_fn=resnet_act_fn,
87
+ resnet_groups=resnet_groups,
88
+ downsample_padding=downsample_padding,
89
+ cross_attention_dim=cross_attention_dim,
90
+ audio_cross_attention_dim=audio_cross_attention_dim,
91
+ attn_num_head_channels=attn_num_head_channels,
92
+ dual_cross_attention=dual_cross_attention,
93
+ use_linear_projection=use_linear_projection,
94
+ only_cross_attention=only_cross_attention,
95
+ upcast_attention=upcast_attention,
96
+ resnet_time_scale_shift=resnet_time_scale_shift
97
+ )
98
+ raise ValueError(f"{down_block_type} does not exist.")
99
+
100
+
101
+ def get_up_block(
102
+ up_block_type,
103
+ num_layers,
104
+ in_channels,
105
+ out_channels,
106
+ prev_output_channel,
107
+ temb_channels,
108
+ add_upsample,
109
+ resnet_eps,
110
+ resnet_act_fn,
111
+ attn_num_head_channels,
112
+ resnet_groups=None,
113
+ cross_attention_dim=None,
114
+ dual_cross_attention=False,
115
+ use_linear_projection=False,
116
+ only_cross_attention=False,
117
+ upcast_attention=False,
118
+ resnet_time_scale_shift="default",
119
+ audio_cross_attention_dim=None
120
+ ):
121
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
122
+ if up_block_type == "FFSpatioTempResUpBlock3D":
123
+ return FFSpatioTempResUpBlock3D(
124
+ num_layers=num_layers,
125
+ in_channels=in_channels,
126
+ out_channels=out_channels,
127
+ prev_output_channel=prev_output_channel,
128
+ temb_channels=temb_channels,
129
+ add_upsample=add_upsample,
130
+ resnet_eps=resnet_eps,
131
+ resnet_act_fn=resnet_act_fn,
132
+ resnet_groups=resnet_groups,
133
+ resnet_time_scale_shift=resnet_time_scale_shift
134
+ )
135
+ elif up_block_type == "FFSpatioTempCrossAttnUpBlock3D":
136
+ if cross_attention_dim is None:
137
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
138
+ return FFSpatioTempCrossAttnUpBlock3D(
139
+ num_layers=num_layers,
140
+ in_channels=in_channels,
141
+ out_channels=out_channels,
142
+ prev_output_channel=prev_output_channel,
143
+ temb_channels=temb_channels,
144
+ add_upsample=add_upsample,
145
+ resnet_eps=resnet_eps,
146
+ resnet_act_fn=resnet_act_fn,
147
+ resnet_groups=resnet_groups,
148
+ cross_attention_dim=cross_attention_dim,
149
+ attn_num_head_channels=attn_num_head_channels,
150
+ dual_cross_attention=dual_cross_attention,
151
+ use_linear_projection=use_linear_projection,
152
+ only_cross_attention=only_cross_attention,
153
+ upcast_attention=upcast_attention,
154
+ resnet_time_scale_shift=resnet_time_scale_shift
155
+ )
156
+ elif up_block_type == "FFSpatioAudioTempCrossAttnUpBlock3D":
157
+ if cross_attention_dim is None:
158
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
159
+ return FFSpatioAudioTempCrossAttnUpBlock3D(
160
+ num_layers=num_layers,
161
+ in_channels=in_channels,
162
+ out_channels=out_channels,
163
+ prev_output_channel=prev_output_channel,
164
+ temb_channels=temb_channels,
165
+ add_upsample=add_upsample,
166
+ resnet_eps=resnet_eps,
167
+ resnet_act_fn=resnet_act_fn,
168
+ resnet_groups=resnet_groups,
169
+ cross_attention_dim=cross_attention_dim,
170
+ audio_cross_attention_dim=audio_cross_attention_dim,
171
+ attn_num_head_channels=attn_num_head_channels,
172
+ dual_cross_attention=dual_cross_attention,
173
+ use_linear_projection=use_linear_projection,
174
+ only_cross_attention=only_cross_attention,
175
+ upcast_attention=upcast_attention,
176
+ resnet_time_scale_shift=resnet_time_scale_shift
177
+ )
178
+ raise ValueError(f"{up_block_type} does not exist.")
179
+
180
+
181
+ def get_mid_block(
182
+ mid_block_type,
183
+ in_channels,
184
+ temb_channels,
185
+ resnet_eps,
186
+ resnet_act_fn,
187
+ output_scale_factor,
188
+ resnet_time_scale_shift,
189
+ cross_attention_dim,
190
+ attn_num_head_channels,
191
+ resnet_groups,
192
+ dual_cross_attention,
193
+ use_linear_projection,
194
+ upcast_attention,
195
+ audio_cross_attention_dim=None
196
+ ):
197
+ if mid_block_type == "FFSpatioTempCrossAttnUNetMidBlock3D":
198
+ return FFSpatioTempCrossAttnUNetMidBlock3D(
199
+ in_channels=in_channels,
200
+ temb_channels=temb_channels,
201
+ resnet_eps=resnet_eps,
202
+ resnet_act_fn=resnet_act_fn,
203
+ output_scale_factor=output_scale_factor,
204
+ resnet_time_scale_shift=resnet_time_scale_shift,
205
+ cross_attention_dim=cross_attention_dim,
206
+ attn_num_head_channels=attn_num_head_channels,
207
+ resnet_groups=resnet_groups,
208
+ dual_cross_attention=dual_cross_attention,
209
+ use_linear_projection=use_linear_projection,
210
+ upcast_attention=upcast_attention
211
+ )
212
+ elif mid_block_type == "FFSpatioAudioTempCrossAttnUNetMidBlock3D":
213
+ return FFSpatioAudioTempCrossAttnUNetMidBlock3D(
214
+ in_channels=in_channels,
215
+ temb_channels=temb_channels,
216
+ resnet_eps=resnet_eps,
217
+ resnet_act_fn=resnet_act_fn,
218
+ output_scale_factor=output_scale_factor,
219
+ resnet_time_scale_shift=resnet_time_scale_shift,
220
+ cross_attention_dim=cross_attention_dim,
221
+ audio_cross_attention_dim=audio_cross_attention_dim,
222
+ attn_num_head_channels=attn_num_head_channels,
223
+ resnet_groups=resnet_groups,
224
+ dual_cross_attention=dual_cross_attention,
225
+ use_linear_projection=use_linear_projection,
226
+ upcast_attention=upcast_attention
227
+ )
228
+ raise ValueError(f"{mid_block_type} does not exist.")
229
+
230
+
231
+ ##### Image Condition Blocks #####
232
+
233
+ class FFSpatioTempResDownBlock3D(nn.Module):
234
+ def __init__(
235
+ self,
236
+ in_channels: int,
237
+ out_channels: int,
238
+ temb_channels: int,
239
+ dropout: float = 0.0,
240
+ num_layers: int = 1,
241
+ resnet_eps: float = 1e-6,
242
+ resnet_time_scale_shift: str = "default",
243
+ resnet_act_fn: str = "swish",
244
+ resnet_groups: int = 32,
245
+ resnet_pre_norm: bool = True,
246
+ output_scale_factor=1.0,
247
+ add_downsample=True,
248
+ downsample_padding=1
249
+ ):
250
+ super().__init__()
251
+ resnets = []
252
+
253
+ for i in range(num_layers):
254
+ in_channels = in_channels if i == 0 else out_channels
255
+ resnets.append(
256
+ FFSpatioTempResnetBlock3D(
257
+ in_channels=in_channels,
258
+ out_channels=out_channels,
259
+ temb_channels=temb_channels,
260
+ eps=resnet_eps,
261
+ groups=resnet_groups,
262
+ dropout=dropout,
263
+ time_embedding_norm=resnet_time_scale_shift,
264
+ non_linearity=resnet_act_fn,
265
+ output_scale_factor=output_scale_factor,
266
+ pre_norm=resnet_pre_norm
267
+ )
268
+ )
269
+
270
+ self.resnets = nn.ModuleList(resnets)
271
+
272
+ if add_downsample:
273
+ self.downsamplers = nn.ModuleList(
274
+ [
275
+ FFSpatioTempResDownsample3D(
276
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
277
+ )
278
+ ]
279
+ )
280
+ else:
281
+ self.downsamplers = None
282
+
283
+ self.gradient_checkpointing = False
284
+
285
+ def forward(self, hidden_states, temb=None):
286
+ output_states = ()
287
+
288
+ for resnet in self.resnets:
289
+ if self.training and self.gradient_checkpointing:
290
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
291
+ else:
292
+ hidden_states = resnet(hidden_states, temb)
293
+
294
+ output_states += (hidden_states,)
295
+
296
+ if self.downsamplers is not None:
297
+ for downsampler in self.downsamplers:
298
+ hidden_states = downsampler(hidden_states)
299
+
300
+ output_states += (hidden_states,)
301
+
302
+ return hidden_states, output_states
303
+
304
+
305
+ class FFSpatioTempResUpBlock3D(nn.Module):
306
+ def __init__(
307
+ self,
308
+ in_channels: int,
309
+ prev_output_channel: int,
310
+ out_channels: int,
311
+ temb_channels: int,
312
+ dropout: float = 0.0,
313
+ num_layers: int = 1,
314
+ resnet_eps: float = 1e-6,
315
+ resnet_time_scale_shift: str = "default",
316
+ resnet_act_fn: str = "swish",
317
+ resnet_groups: int = 32,
318
+ resnet_pre_norm: bool = True,
319
+ output_scale_factor=1.0,
320
+ add_upsample=True
321
+ ):
322
+ super().__init__()
323
+ resnets = []
324
+
325
+ for i in range(num_layers):
326
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
327
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
328
+
329
+ resnets.append(
330
+ FFSpatioTempResnetBlock3D(
331
+ in_channels=resnet_in_channels + res_skip_channels,
332
+ out_channels=out_channels,
333
+ temb_channels=temb_channels,
334
+ eps=resnet_eps,
335
+ groups=resnet_groups,
336
+ dropout=dropout,
337
+ time_embedding_norm=resnet_time_scale_shift,
338
+ non_linearity=resnet_act_fn,
339
+ output_scale_factor=output_scale_factor,
340
+ pre_norm=resnet_pre_norm
341
+ )
342
+ )
343
+
344
+ self.resnets = nn.ModuleList(resnets)
345
+
346
+ if add_upsample:
347
+ self.upsamplers = nn.ModuleList(
348
+ [FFSpatioTempResUpsample3D(out_channels, use_conv=True, out_channels=out_channels)])
349
+ else:
350
+ self.upsamplers = None
351
+
352
+ self.gradient_checkpointing = False
353
+
354
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
355
+ for resnet in self.resnets:
356
+ # pop res hidden states
357
+ res_hidden_states = res_hidden_states_tuple[-1]
358
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
359
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
360
+
361
+ if self.training and self.gradient_checkpointing:
362
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
363
+ else:
364
+ hidden_states = resnet(hidden_states, temb)
365
+
366
+ if self.upsamplers is not None:
367
+ for upsampler in self.upsamplers:
368
+ hidden_states = upsampler(hidden_states, upsample_size)
369
+
370
+ return hidden_states
371
+
372
+
373
+ class FFSpatioTempCrossAttnUNetMidBlock3D(nn.Module):
374
+ def __init__(
375
+ self,
376
+ in_channels: int,
377
+ temb_channels: int,
378
+ dropout: float = 0.0,
379
+ num_layers: int = 1,
380
+ resnet_eps: float = 1e-6,
381
+ resnet_time_scale_shift: str = "default",
382
+ resnet_act_fn: str = "swish",
383
+ resnet_groups: int = 32,
384
+ resnet_pre_norm: bool = True,
385
+ attn_num_head_channels=1,
386
+ output_scale_factor=1.0,
387
+ cross_attention_dim=1280,
388
+ dual_cross_attention=False,
389
+ use_linear_projection=False,
390
+ upcast_attention=False
391
+ ):
392
+ super().__init__()
393
+
394
+ self.has_cross_attention = True
395
+ self.attn_num_head_channels = attn_num_head_channels
396
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
397
+
398
+ # there is always at least one resnet
399
+ resnets = [
400
+ FFSpatioTempResnetBlock3D(
401
+ in_channels=in_channels,
402
+ out_channels=in_channels,
403
+ temb_channels=temb_channels,
404
+ eps=resnet_eps,
405
+ groups=resnet_groups,
406
+ dropout=dropout,
407
+ time_embedding_norm=resnet_time_scale_shift,
408
+ non_linearity=resnet_act_fn,
409
+ output_scale_factor=output_scale_factor,
410
+ pre_norm=resnet_pre_norm
411
+ )
412
+ ]
413
+ attentions = []
414
+
415
+ for _ in range(num_layers):
416
+ if dual_cross_attention:
417
+ raise NotImplementedError
418
+ attentions.append(
419
+ FFSpatioTempTransformer3DModel(
420
+ attn_num_head_channels,
421
+ in_channels // attn_num_head_channels,
422
+ in_channels=in_channels,
423
+ num_layers=1,
424
+ cross_attention_dim=cross_attention_dim,
425
+ norm_num_groups=resnet_groups,
426
+ use_linear_projection=use_linear_projection,
427
+ upcast_attention=upcast_attention,
428
+ )
429
+ )
430
+ resnets.append(
431
+ FFSpatioTempResnetBlock3D(
432
+ in_channels=in_channels,
433
+ out_channels=in_channels,
434
+ temb_channels=temb_channels,
435
+ eps=resnet_eps,
436
+ groups=resnet_groups,
437
+ dropout=dropout,
438
+ time_embedding_norm=resnet_time_scale_shift,
439
+ non_linearity=resnet_act_fn,
440
+ output_scale_factor=output_scale_factor,
441
+ pre_norm=resnet_pre_norm,
442
+
443
+ )
444
+ )
445
+
446
+ self.attentions = nn.ModuleList(attentions)
447
+ self.resnets = nn.ModuleList(resnets)
448
+
449
+ self.gradient_checkpointing = False
450
+
451
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None,
452
+ cross_attention_kwargs=None):
453
+ if self.training and self.gradient_checkpointing:
454
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.resnets[0]), hidden_states,
455
+ temb)
456
+ else:
457
+ hidden_states = self.resnets[0](hidden_states, temb)
458
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
459
+ if self.training and self.gradient_checkpointing:
460
+ hidden_states = torch.utils.checkpoint.checkpoint(
461
+ create_custom_forward(attn, return_dict=False),
462
+ hidden_states,
463
+ encoder_hidden_states,
464
+ cross_attention_kwargs
465
+ )[0]
466
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
467
+ else:
468
+ hidden_states = attn(
469
+ hidden_states,
470
+ encoder_hidden_states=encoder_hidden_states,
471
+ cross_attention_kwargs=cross_attention_kwargs
472
+ ).sample
473
+ hidden_states = resnet(hidden_states, temb)
474
+
475
+ return hidden_states
476
+
477
+
478
+ class FFSpatioTempCrossAttnDownBlock3D(nn.Module):
479
+ def __init__(
480
+ self,
481
+ in_channels: int,
482
+ out_channels: int,
483
+ temb_channels: int,
484
+ dropout: float = 0.0,
485
+ num_layers: int = 1,
486
+ resnet_eps: float = 1e-6,
487
+ resnet_time_scale_shift: str = "default",
488
+ resnet_act_fn: str = "swish",
489
+ resnet_groups: int = 32,
490
+ resnet_pre_norm: bool = True,
491
+ attn_num_head_channels=1,
492
+ cross_attention_dim=1280,
493
+ output_scale_factor=1.0,
494
+ downsample_padding=1,
495
+ add_downsample=True,
496
+ dual_cross_attention=False,
497
+ use_linear_projection=False,
498
+ only_cross_attention=False,
499
+ upcast_attention=False,
500
+
501
+ ):
502
+ super().__init__()
503
+ resnets = []
504
+ attentions = []
505
+
506
+ self.has_cross_attention = True
507
+ self.attn_num_head_channels = attn_num_head_channels
508
+
509
+ for i in range(num_layers):
510
+ in_channels = in_channels if i == 0 else out_channels
511
+ resnets.append(
512
+ FFSpatioTempResnetBlock3D(
513
+ in_channels=in_channels,
514
+ out_channels=out_channels,
515
+ temb_channels=temb_channels,
516
+ eps=resnet_eps,
517
+ groups=resnet_groups,
518
+ dropout=dropout,
519
+ time_embedding_norm=resnet_time_scale_shift,
520
+ non_linearity=resnet_act_fn,
521
+ output_scale_factor=output_scale_factor,
522
+ pre_norm=resnet_pre_norm,
523
+
524
+ )
525
+ )
526
+ if dual_cross_attention:
527
+ raise NotImplementedError
528
+ attentions.append(
529
+ FFSpatioTempTransformer3DModel(
530
+ attn_num_head_channels,
531
+ out_channels // attn_num_head_channels,
532
+ in_channels=out_channels,
533
+ num_layers=1,
534
+ cross_attention_dim=cross_attention_dim,
535
+ norm_num_groups=resnet_groups,
536
+ use_linear_projection=use_linear_projection,
537
+ only_cross_attention=only_cross_attention,
538
+ upcast_attention=upcast_attention,
539
+ )
540
+ )
541
+ self.attentions = nn.ModuleList(attentions)
542
+ self.resnets = nn.ModuleList(resnets)
543
+
544
+ if add_downsample:
545
+ self.downsamplers = nn.ModuleList(
546
+ [
547
+ FFSpatioTempResDownsample3D(
548
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op",
549
+
550
+ )
551
+ ]
552
+ )
553
+ else:
554
+ self.downsamplers = None
555
+
556
+ self.gradient_checkpointing = False
557
+
558
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None,
559
+ cross_attention_kwargs=None):
560
+ output_states = ()
561
+
562
+ for resnet, attn in zip(self.resnets, self.attentions):
563
+ if self.training and self.gradient_checkpointing:
564
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
565
+ hidden_states = torch.utils.checkpoint.checkpoint(
566
+ create_custom_forward(attn, return_dict=False),
567
+ hidden_states,
568
+ encoder_hidden_states,
569
+ cross_attention_kwargs
570
+ )[0]
571
+ else:
572
+ hidden_states = resnet(hidden_states, temb)
573
+ hidden_states = attn(
574
+ hidden_states,
575
+ encoder_hidden_states=encoder_hidden_states,
576
+ cross_attention_kwargs=cross_attention_kwargs,
577
+ ).sample
578
+
579
+ output_states += (hidden_states,)
580
+
581
+ if self.downsamplers is not None:
582
+ for downsampler in self.downsamplers:
583
+ hidden_states = downsampler(hidden_states)
584
+
585
+ output_states += (hidden_states,)
586
+
587
+ return hidden_states, output_states
588
+
589
+
590
+ class FFSpatioTempCrossAttnUpBlock3D(nn.Module):
591
+ def __init__(
592
+ self,
593
+ in_channels: int,
594
+ out_channels: int,
595
+ prev_output_channel: int,
596
+ temb_channels: int,
597
+ dropout: float = 0.0,
598
+ num_layers: int = 1,
599
+ resnet_eps: float = 1e-6,
600
+ resnet_time_scale_shift: str = "default",
601
+ resnet_act_fn: str = "swish",
602
+ resnet_groups: int = 32,
603
+ resnet_pre_norm: bool = True,
604
+ attn_num_head_channels=1,
605
+ cross_attention_dim=1280,
606
+ output_scale_factor=1.0,
607
+ add_upsample=True,
608
+ dual_cross_attention=False,
609
+ use_linear_projection=False,
610
+ only_cross_attention=False,
611
+ upcast_attention=False,
612
+
613
+ ):
614
+ super().__init__()
615
+ resnets = []
616
+ attentions = []
617
+
618
+ self.has_cross_attention = True
619
+ self.attn_num_head_channels = attn_num_head_channels
620
+
621
+ for i in range(num_layers):
622
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
623
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
624
+
625
+ resnets.append(
626
+ FFSpatioTempResnetBlock3D(
627
+ in_channels=resnet_in_channels + res_skip_channels,
628
+ out_channels=out_channels,
629
+ temb_channels=temb_channels,
630
+ eps=resnet_eps,
631
+ groups=resnet_groups,
632
+ dropout=dropout,
633
+ time_embedding_norm=resnet_time_scale_shift,
634
+ non_linearity=resnet_act_fn,
635
+ output_scale_factor=output_scale_factor,
636
+ pre_norm=resnet_pre_norm,
637
+
638
+ )
639
+ )
640
+ if dual_cross_attention:
641
+ raise NotImplementedError
642
+ attentions.append(
643
+ FFSpatioTempTransformer3DModel(
644
+ attn_num_head_channels,
645
+ out_channels // attn_num_head_channels,
646
+ in_channels=out_channels,
647
+ num_layers=1,
648
+ cross_attention_dim=cross_attention_dim,
649
+ norm_num_groups=resnet_groups,
650
+ use_linear_projection=use_linear_projection,
651
+ only_cross_attention=only_cross_attention,
652
+ upcast_attention=upcast_attention,
653
+ )
654
+ )
655
+
656
+ self.attentions = nn.ModuleList(attentions)
657
+ self.resnets = nn.ModuleList(resnets)
658
+
659
+ if add_upsample:
660
+ self.upsamplers = nn.ModuleList(
661
+ [FFSpatioTempResUpsample3D(out_channels, use_conv=True, out_channels=out_channels,
662
+ )])
663
+ else:
664
+ self.upsamplers = None
665
+
666
+ self.gradient_checkpointing = False
667
+
668
+ def forward(
669
+ self,
670
+ hidden_states,
671
+ res_hidden_states_tuple,
672
+ temb=None,
673
+ encoder_hidden_states=None,
674
+ upsample_size=None,
675
+ attention_mask=None,
676
+ cross_attention_kwargs=None
677
+ ):
678
+ for resnet, attn in zip(self.resnets, self.attentions):
679
+ # pop res hidden states
680
+ res_hidden_states = res_hidden_states_tuple[-1]
681
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
682
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
683
+
684
+ if self.training and self.gradient_checkpointing:
685
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
686
+ hidden_states = torch.utils.checkpoint.checkpoint(
687
+ create_custom_forward(attn, return_dict=False),
688
+ hidden_states,
689
+ encoder_hidden_states,
690
+ cross_attention_kwargs
691
+ )[0]
692
+ else:
693
+ hidden_states = resnet(hidden_states, temb)
694
+ hidden_states = attn(
695
+ hidden_states,
696
+ encoder_hidden_states=encoder_hidden_states,
697
+ cross_attention_kwargs=cross_attention_kwargs,
698
+ ).sample
699
+
700
+ if self.upsamplers is not None:
701
+ for upsampler in self.upsamplers:
702
+ hidden_states = upsampler(hidden_states, upsample_size)
703
+
704
+ return hidden_states
705
+
706
+
707
+ ##### Audio Condition Blocks #####
708
+
709
+ class FFSpatioAudioTempCrossAttnUNetMidBlock3D(nn.Module):
710
+ def __init__(
711
+ self,
712
+ in_channels: int,
713
+ temb_channels: int,
714
+ dropout: float = 0.0,
715
+ num_layers: int = 1,
716
+ resnet_eps: float = 1e-6,
717
+ resnet_time_scale_shift: str = "default",
718
+ resnet_act_fn: str = "swish",
719
+ resnet_groups: int = 32,
720
+ resnet_pre_norm: bool = True,
721
+ attn_num_head_channels=1,
722
+ output_scale_factor=1.0,
723
+ cross_attention_dim=1280,
724
+ audio_cross_attention_dim=768,
725
+ dual_cross_attention=False,
726
+ use_linear_projection=False,
727
+ upcast_attention=False,
728
+
729
+ ):
730
+ super().__init__()
731
+
732
+ self.has_cross_attention = True
733
+ self.attn_num_head_channels = attn_num_head_channels
734
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
735
+
736
+ # there is always at least one resnet
737
+ resnets = [
738
+ FFSpatioTempResnetBlock3D(
739
+ in_channels=in_channels,
740
+ out_channels=in_channels,
741
+ temb_channels=temb_channels,
742
+ eps=resnet_eps,
743
+ groups=resnet_groups,
744
+ dropout=dropout,
745
+ time_embedding_norm=resnet_time_scale_shift,
746
+ non_linearity=resnet_act_fn,
747
+ output_scale_factor=output_scale_factor,
748
+ pre_norm=resnet_pre_norm,
749
+
750
+ )
751
+ ]
752
+ attentions = []
753
+
754
+ for _ in range(num_layers):
755
+ if dual_cross_attention:
756
+ raise NotImplementedError
757
+ attentions.append(
758
+ FFSpatioAudioTempTransformer3DModel(
759
+ attn_num_head_channels,
760
+ in_channels // attn_num_head_channels,
761
+ in_channels=in_channels,
762
+ num_layers=1,
763
+ cross_attention_dim=cross_attention_dim,
764
+ audio_cross_attention_dim=audio_cross_attention_dim,
765
+ norm_num_groups=resnet_groups,
766
+ use_linear_projection=use_linear_projection,
767
+ upcast_attention=upcast_attention,
768
+ )
769
+ )
770
+ resnets.append(
771
+ FFSpatioTempResnetBlock3D(
772
+ in_channels=in_channels,
773
+ out_channels=in_channels,
774
+ temb_channels=temb_channels,
775
+ eps=resnet_eps,
776
+ groups=resnet_groups,
777
+ dropout=dropout,
778
+ time_embedding_norm=resnet_time_scale_shift,
779
+ non_linearity=resnet_act_fn,
780
+ output_scale_factor=output_scale_factor,
781
+ pre_norm=resnet_pre_norm,
782
+
783
+ )
784
+ )
785
+
786
+ self.attentions = nn.ModuleList(attentions)
787
+ self.resnets = nn.ModuleList(resnets)
788
+
789
+ self.gradient_checkpointing = False
790
+
791
+ def forward(self, hidden_states, temb=None,
792
+ encoder_hidden_states=None, attention_mask=None,
793
+ audio_encoder_hidden_states=None, audio_attention_mask=None,
794
+ cross_attention_kwargs=None):
795
+ assert cross_attention_kwargs is None
796
+ if self.training and self.gradient_checkpointing:
797
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.resnets[0]), hidden_states,
798
+ temb)
799
+ else:
800
+ hidden_states = self.resnets[0](hidden_states, temb)
801
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
802
+ if self.training and self.gradient_checkpointing:
803
+ hidden_states = torch.utils.checkpoint.checkpoint(
804
+ create_custom_forward(attn, return_dict=False),
805
+ hidden_states,
806
+ encoder_hidden_states,
807
+ audio_encoder_hidden_states,
808
+ audio_attention_mask,
809
+ )[0]
810
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
811
+ else:
812
+ hidden_states = attn(
813
+ hidden_states,
814
+ encoder_hidden_states=encoder_hidden_states,
815
+ audio_encoder_hidden_states=audio_encoder_hidden_states,
816
+ audio_attention_mask=audio_attention_mask,
817
+ cross_attention_kwargs=cross_attention_kwargs
818
+ ).sample
819
+ hidden_states = resnet(hidden_states, temb)
820
+
821
+ return hidden_states
822
+
823
+
824
+ class FFSpatioAudioTempCrossAttnDownBlock3D(nn.Module):
825
+ def __init__(
826
+ self,
827
+ in_channels: int,
828
+ out_channels: int,
829
+ temb_channels: int,
830
+ dropout: float = 0.0,
831
+ num_layers: int = 1,
832
+ resnet_eps: float = 1e-6,
833
+ resnet_time_scale_shift: str = "default",
834
+ resnet_act_fn: str = "swish",
835
+ resnet_groups: int = 32,
836
+ resnet_pre_norm: bool = True,
837
+ attn_num_head_channels=1,
838
+ cross_attention_dim=1280,
839
+ audio_cross_attention_dim=768,
840
+ output_scale_factor=1.0,
841
+ downsample_padding=1,
842
+ add_downsample=True,
843
+ dual_cross_attention=False,
844
+ use_linear_projection=False,
845
+ only_cross_attention=False,
846
+ upcast_attention=False,
847
+
848
+ ):
849
+ super().__init__()
850
+ resnets = []
851
+ attentions = []
852
+
853
+ self.has_cross_attention = True
854
+ self.attn_num_head_channels = attn_num_head_channels
855
+
856
+ for i in range(num_layers):
857
+ in_channels = in_channels if i == 0 else out_channels
858
+ resnets.append(
859
+ FFSpatioTempResnetBlock3D(
860
+ in_channels=in_channels,
861
+ out_channels=out_channels,
862
+ temb_channels=temb_channels,
863
+ eps=resnet_eps,
864
+ groups=resnet_groups,
865
+ dropout=dropout,
866
+ time_embedding_norm=resnet_time_scale_shift,
867
+ non_linearity=resnet_act_fn,
868
+ output_scale_factor=output_scale_factor,
869
+ pre_norm=resnet_pre_norm,
870
+
871
+ )
872
+ )
873
+ if dual_cross_attention:
874
+ raise NotImplementedError
875
+ attentions.append(
876
+ FFSpatioAudioTempTransformer3DModel(
877
+ attn_num_head_channels,
878
+ out_channels // attn_num_head_channels,
879
+ in_channels=out_channels,
880
+ num_layers=1,
881
+ cross_attention_dim=cross_attention_dim,
882
+ audio_cross_attention_dim=audio_cross_attention_dim,
883
+ norm_num_groups=resnet_groups,
884
+ use_linear_projection=use_linear_projection,
885
+ only_cross_attention=only_cross_attention,
886
+ upcast_attention=upcast_attention
887
+ )
888
+ )
889
+ self.attentions = nn.ModuleList(attentions)
890
+ self.resnets = nn.ModuleList(resnets)
891
+
892
+ if add_downsample:
893
+ self.downsamplers = nn.ModuleList(
894
+ [
895
+ FFSpatioTempResDownsample3D(
896
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op",
897
+
898
+ )
899
+ ]
900
+ )
901
+ else:
902
+ self.downsamplers = None
903
+
904
+ self.gradient_checkpointing = False
905
+
906
+ def forward(self, hidden_states, temb=None,
907
+ encoder_hidden_states=None, attention_mask=None,
908
+ audio_encoder_hidden_states=None, audio_attention_mask=None,
909
+ cross_attention_kwargs=None):
910
+ output_states = ()
911
+
912
+ for resnet, attn in zip(self.resnets, self.attentions):
913
+ if self.training and self.gradient_checkpointing:
914
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
915
+ hidden_states = torch.utils.checkpoint.checkpoint(
916
+ create_custom_forward(attn, return_dict=False),
917
+ hidden_states,
918
+ encoder_hidden_states,
919
+ audio_encoder_hidden_states,
920
+ audio_attention_mask
921
+ )[0]
922
+ else:
923
+ hidden_states = resnet(hidden_states, temb)
924
+ hidden_states = attn(
925
+ hidden_states,
926
+ encoder_hidden_states=encoder_hidden_states,
927
+ audio_encoder_hidden_states=audio_encoder_hidden_states,
928
+ audio_attention_mask=audio_attention_mask,
929
+ cross_attention_kwargs=cross_attention_kwargs,
930
+ ).sample
931
+
932
+ output_states += (hidden_states,)
933
+
934
+ if self.downsamplers is not None:
935
+ for downsampler in self.downsamplers:
936
+ hidden_states = downsampler(hidden_states)
937
+
938
+ output_states += (hidden_states,)
939
+
940
+ return hidden_states, output_states
941
+
942
+
943
+ class FFSpatioAudioTempCrossAttnUpBlock3D(nn.Module):
944
+ def __init__(
945
+ self,
946
+ in_channels: int,
947
+ out_channels: int,
948
+ prev_output_channel: int,
949
+ temb_channels: int,
950
+ dropout: float = 0.0,
951
+ num_layers: int = 1,
952
+ resnet_eps: float = 1e-6,
953
+ resnet_time_scale_shift: str = "default",
954
+ resnet_act_fn: str = "swish",
955
+ resnet_groups: int = 32,
956
+ resnet_pre_norm: bool = True,
957
+ attn_num_head_channels=1,
958
+ cross_attention_dim=1280,
959
+ audio_cross_attention_dim=768,
960
+ output_scale_factor=1.0,
961
+ add_upsample=True,
962
+ dual_cross_attention=False,
963
+ use_linear_projection=False,
964
+ only_cross_attention=False,
965
+ upcast_attention=False,
966
+
967
+ ):
968
+ super().__init__()
969
+ resnets = []
970
+ attentions = []
971
+
972
+ self.has_cross_attention = True
973
+ self.attn_num_head_channels = attn_num_head_channels
974
+
975
+ for i in range(num_layers):
976
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
977
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
978
+
979
+ resnets.append(
980
+ FFSpatioTempResnetBlock3D(
981
+ in_channels=resnet_in_channels + res_skip_channels,
982
+ out_channels=out_channels,
983
+ temb_channels=temb_channels,
984
+ eps=resnet_eps,
985
+ groups=resnet_groups,
986
+ dropout=dropout,
987
+ time_embedding_norm=resnet_time_scale_shift,
988
+ non_linearity=resnet_act_fn,
989
+ output_scale_factor=output_scale_factor,
990
+ pre_norm=resnet_pre_norm,
991
+
992
+ )
993
+ )
994
+ if dual_cross_attention:
995
+ raise NotImplementedError
996
+ attentions.append(
997
+ FFSpatioAudioTempTransformer3DModel(
998
+ attn_num_head_channels,
999
+ out_channels // attn_num_head_channels,
1000
+ in_channels=out_channels,
1001
+ num_layers=1,
1002
+ cross_attention_dim=cross_attention_dim,
1003
+ audio_cross_attention_dim=audio_cross_attention_dim,
1004
+ norm_num_groups=resnet_groups,
1005
+ use_linear_projection=use_linear_projection,
1006
+ only_cross_attention=only_cross_attention,
1007
+ upcast_attention=upcast_attention,
1008
+ )
1009
+ )
1010
+
1011
+ self.attentions = nn.ModuleList(attentions)
1012
+ self.resnets = nn.ModuleList(resnets)
1013
+
1014
+ if add_upsample:
1015
+ self.upsamplers = nn.ModuleList(
1016
+ [FFSpatioTempResUpsample3D(out_channels, use_conv=True, out_channels=out_channels,
1017
+ )])
1018
+ else:
1019
+ self.upsamplers = None
1020
+
1021
+ self.gradient_checkpointing = False
1022
+
1023
+ def forward(
1024
+ self,
1025
+ hidden_states,
1026
+ res_hidden_states_tuple,
1027
+ temb=None,
1028
+ encoder_hidden_states=None,
1029
+ attention_mask=None,
1030
+ audio_encoder_hidden_states=None,
1031
+ audio_attention_mask=None,
1032
+ upsample_size=None,
1033
+ cross_attention_kwargs=None
1034
+ ):
1035
+ assert cross_attention_kwargs is None
1036
+ for resnet, attn in zip(self.resnets, self.attentions):
1037
+ # pop res hidden states
1038
+ res_hidden_states = res_hidden_states_tuple[-1]
1039
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1040
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1041
+
1042
+ if self.training and self.gradient_checkpointing:
1043
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1044
+ hidden_states = torch.utils.checkpoint.checkpoint(
1045
+ create_custom_forward(attn, return_dict=False),
1046
+ hidden_states,
1047
+ encoder_hidden_states,
1048
+ audio_encoder_hidden_states,
1049
+ audio_attention_mask,
1050
+ cross_attention_kwargs
1051
+ )[0]
1052
+ else:
1053
+ hidden_states = resnet(hidden_states, temb)
1054
+ hidden_states = attn(
1055
+ hidden_states,
1056
+ encoder_hidden_states=encoder_hidden_states,
1057
+ audio_encoder_hidden_states=audio_encoder_hidden_states,
1058
+ audio_attention_mask=audio_attention_mask,
1059
+ cross_attention_kwargs=cross_attention_kwargs,
1060
+ ).sample
1061
+
1062
+ if self.upsamplers is not None:
1063
+ for upsampler in self.upsamplers:
1064
+ hidden_states = upsampler(hidden_states, upsample_size)
1065
+
1066
+ return hidden_states
1067
+
1068
+
1069
+ all_modules = [
1070
+ ##### Image Condition #####
1071
+
1072
+ FFSpatioTempResDownBlock3D,
1073
+ FFSpatioTempResUpBlock3D,
1074
+
1075
+ FFSpatioTempCrossAttnUNetMidBlock3D,
1076
+ FFSpatioTempCrossAttnDownBlock3D,
1077
+ FFSpatioTempCrossAttnUpBlock3D,
1078
+
1079
+ ##### Audio Condition #####
1080
+
1081
+ FFSpatioAudioTempCrossAttnUNetMidBlock3D,
1082
+ FFSpatioAudioTempCrossAttnDownBlock3D,
1083
+ FFSpatioAudioTempCrossAttnUpBlock3D,
1084
+ ]
unet_utils.py ADDED
@@ -0,0 +1,163 @@
1
+ from typing import Optional
2
+ from einops import rearrange
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.models.attention import Attention
9
+
10
+
11
+ class InflatedConv3d(nn.Conv2d):
12
+ def forward(self, x):
13
+ video_length = x.shape[2]
14
+
15
+ x = rearrange(x, "b c f h w -> (b f) c h w")
16
+ x = super().forward(x)
17
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
18
+
19
+ return x
20
+
21
+
22
+ class FFInflatedConv3d(nn.Conv2d):
23
+ def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
24
+ super().__init__(
25
+ in_channels=in_channels,
26
+ out_channels=out_channels,
27
+ kernel_size=kernel_size,
28
+ **kwargs,
29
+ )
30
+ self.conv_temp = nn.Linear(3 * out_channels, out_channels)
31
+ nn.init.zeros_(self.conv_temp.weight.data)  # initialized to zeros
32
+ nn.init.zeros_(self.conv_temp.bias.data)
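+ # zero-initialising conv_temp makes the temporal branch a no-op at the start of training, so the
+ # inflated 3D convolution initially reproduces the pretrained 2D convolution frame by frame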
33
+
34
+ def forward(self, x):
35
+ video_length = x.shape[2]
36
+
37
+ x = rearrange(x, "b c f h w -> (b f) c h w")
38
+ x = super().forward(x)
39
+
40
+ *_, h, w = x.shape
41
+ x = rearrange(x, "(b f) c h w -> (b h w) f c", f=video_length)
42
+
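+ # for every spatial location, gather features of the first frame, the previous frame and the
+ # current frame, concatenate them, and add a learned residual temporal mix via conv_temp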
43
+ head_frame_index = [0, ] * video_length
44
+ prev_frame_index = torch.clamp(
45
+ torch.arange(video_length) - 1, min=0.0
46
+ ).long()
47
+ curr_frame_index = torch.arange(video_length).long()
48
+ conv_temp_nn_input = torch.cat([
49
+ x[:, head_frame_index],
50
+ x[:, prev_frame_index],
51
+ x[:, curr_frame_index]
52
+ ], dim=2).contiguous()
53
+ x = x + self.conv_temp(conv_temp_nn_input)
54
+
55
+ x = rearrange(x, "(b h w) f c -> b c f h w", h=h, w=w)
56
+
57
+ return x
58
+
59
+
60
+ class FFAttention(Attention):
61
+ r"""
62
+ A cross attention layer.
63
+
64
+ Parameters:
65
+ query_dim (`int`): The number of channels in the query.
66
+ cross_attention_dim (`int`, *optional*):
67
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
68
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
69
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
70
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
71
+ bias (`bool`, *optional*, defaults to False):
72
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ *args,
78
+ scale_qk: bool = True,
79
+ processor: Optional["FFAttnProcessor"] = None,
80
+ **kwargs
81
+ ):
82
+ super().__init__(*args, scale_qk=scale_qk, processor=processor, **kwargs)
83
+ # set attention processor
84
+ # We use the AttnProcessor by default when torch 2.x is used which uses
85
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
86
+ # but only if it has the default `scale` argument.
87
+ if processor is None:
88
+ processor = FFAttnProcessor()
89
+ self.set_processor(processor)
90
+
91
+ def forward(self, hidden_states, video_length, encoder_hidden_states=None, attention_mask=None,
92
+ **cross_attention_kwargs):
93
+ # The `Attention` class can call different attention processors / attention functions
94
+ # here we simply pass along all tensors to the selected processor class
95
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
96
+ return self.processor(
97
+ self,
98
+ hidden_states,
99
+ encoder_hidden_states=encoder_hidden_states,
100
+ attention_mask=attention_mask,
101
+ video_length=video_length,
102
+ **cross_attention_kwargs,
103
+ )
104
+
105
+
106
+ class FFAttnProcessor:
107
+ def __init__(self):
108
+ if not hasattr(F, "scaled_dot_product_attention"):
109
+ raise ImportError(
110
+ "FFAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
111
+
112
+ def __call__(self, attn: Attention, hidden_states, video_length, encoder_hidden_states=None, attention_mask=None):
113
+ batch_size, sequence_length, _ = (
114
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
115
+ )
116
+ inner_dim = hidden_states.shape[-1]
117
+
118
+ if attention_mask is not None:
119
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
120
+ # scaled_dot_product_attention expects attention_mask shape to be
121
+ # (batch, heads, source_length, target_length)
122
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
123
+
124
+ query = attn.to_q(hidden_states)
125
+
126
+ if encoder_hidden_states is None:
127
+ encoder_hidden_states = hidden_states
128
+ elif attn.norm_cross:
129
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
130
+
131
+ key = attn.to_k(encoder_hidden_states)
132
+ value = attn.to_v(encoder_hidden_states)
133
+
134
+ # sparse causal attention: every frame attends to the keys/values of the first frame only
+ # (former_frame_index below is computed but not used in this first-frame variant)
135
+ former_frame_index = torch.arange(video_length) - 1
136
+ former_frame_index[0] = 0
137
+
138
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
139
+ key = key[:, [0] * video_length].contiguous()
140
+ key = rearrange(key, "b f d c -> (b f) d c")
141
+
142
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
143
+ value = value[:, [0] * video_length].contiguous()
144
+ value = rearrange(value, "b f d c -> (b f) d c")
145
+
146
+ head_dim = inner_dim // attn.heads
147
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
148
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
149
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
150
+
151
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
152
+ hidden_states = F.scaled_dot_product_attention(
153
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
154
+ )
155
+
156
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
157
+ hidden_states = hidden_states.to(query.dtype)
158
+
159
+ # linear proj
160
+ hidden_states = attn.to_out[0](hidden_states)
161
+ # dropout
162
+ hidden_states = attn.to_out[1](hidden_states)
163
+ return hidden_states
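+ 
+ # Illustrative usage sketch (instance values are assumptions): FFAttention works on frame-flattened
+ # (batch * frames, tokens, dim) inputs and needs the frame count explicitly; each frame attends to
+ # the keys/values of the first frame.
+ #   attn = FFAttention(query_dim=320, heads=8, dim_head=40)
+ #   out = attn(hidden_states, video_length=16)  # hidden_states: (batch * 16, tokens, 320)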