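"""Extract ImageBind features.

Subclasses ImageBindModel so that the forward pass returns both the
token-level trunk output and the pooled, post-processed embedding, and
provides helpers for extracting features from text and audio inputs.
"""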
import os

import torch

from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType


class FeatureExtractor(imagebind_model.ImageBindModel):
    """ImageBindModel variant that exposes intermediate features.

    Unlike the stock forward pass, this returns both the token-level trunk
    output (``word_feat``) and the pooled, post-processed embedding
    (``seq_feat``). It is intended to be called with one modality at a time;
    if several are passed, only the last modality's features are returned.
    """

    def forward(self, inputs):
        word_feat, seq_feat = None, None
        for modality_key, modality_value in inputs.items():
            # Audio and video inputs consist of multiple clips; fold the
            # clip dimension into the batch dimension before the trunk.
            reduce_list = modality_value.ndim >= 5
            if reduce_list:
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(
                    B * S, *modality_value.shape[2:]
                )

            if modality_value is not None:
                modality_value = self.modality_preprocessors[modality_key](
                    **{modality_key: modality_value}
                )
                trunk_inputs = modality_value["trunk"]
                head_inputs = modality_value["head"]
                # Token-level features straight from the transformer trunk.
                word_feat = self.modality_trunks[modality_key](**trunk_inputs)
                # Pooled sequence-level embedding after head and postprocessor.
                seq_feat = self.modality_heads[modality_key](
                    word_feat, **head_inputs
                )
                seq_feat = self.modality_postprocessors[modality_key](seq_feat)
        return word_feat, seq_feat

def imagebind_huge(pretrained=False, ckpt_path=None):
    model = FeatureExtractor(
        vision_embed_dim=1280,
        vision_num_blocks=32,
        vision_num_heads=16,
        text_embed_dim=1024,
        text_num_blocks=24,
        text_num_heads=16,
        out_embed_dim=1024,
        audio_drop_path=0.1,
        imu_drop_path=0.7,
    )

    if pretrained:
        if ckpt_path is None:
            raise ValueError("ckpt_path must be set when pretrained=True")
        if not os.path.exists(ckpt_path):
            print(f"Downloading imagebind weights to {ckpt_path} ...")
            os.makedirs(os.path.dirname(ckpt_path) or ".", exist_ok=True)
            torch.hub.download_url_to_file(
                "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
                ckpt_path,
                progress=True,
            )

        model.load_state_dict(torch.load(ckpt_path))
    return model


def extract_text_feature(text, model, device):
    # `text` is a list of strings; ImageBind tokenizes and embeds each one.
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(text, device),
    }
    with torch.no_grad():
        text_word_feat, text_seq_feat = model(inputs)
    return text_word_feat, text_seq_feat


def extract_audio_feature(audio_paths, model, device):
    inputs = {
        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device)
    }
    # Keep only the first clip of each audio file: (B, S, ...) -> (B, 1, ...).
    inputs[ModalityType.AUDIO] = inputs[ModalityType.AUDIO][:, :1]
    with torch.no_grad():
        audio_word_feat, audio_seq_feat = model(inputs)
    return audio_word_feat, audio_seq_feat
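

# Minimal usage sketch (not part of the original module): builds the model and
# runs the two extraction helpers. The checkpoint filename and audio path are
# placeholders; substitute your own.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = imagebind_huge(pretrained=True, ckpt_path="imagebind_huge.pth")
    model.eval()
    model.to(device)

    # Token-level features are (batch, tokens, dim); pooled ones are (batch, dim).
    word_feat, seq_feat = extract_text_feature(
        ["a person waves hello"], model, device
    )
    print(word_feat.shape, seq_feat.shape)

    audio_word_feat, audio_seq_feat = extract_audio_feature(
        ["example.wav"], model, device
    )
    print(audio_word_feat.shape, audio_seq_feat.shape)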