import os

import torch

from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
class FeatureExtractor(imagebind_model.ImageBindModel):
    """ImageBind variant whose forward pass returns both the token-level
    ("word") features from the modality trunk and the pooled, projected
    sequence embedding. Intended to be called with a single modality per
    invocation (see the extract_* helpers below)."""

    def forward(self, inputs):
        word_feat, seq_feat = None, None
        for modality_key, modality_value in inputs.items():
            if modality_value is None:
                continue
            # Audio and video inputs consist of multiple clips; fold the
            # clip dimension into the batch dimension before preprocessing.
            if modality_value.ndim >= 5:
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(
                    B * S, *modality_value.shape[2:]
                )
            modality_value = self.modality_preprocessors[modality_key](
                **{modality_key: modality_value}
            )
            trunk_inputs = modality_value["trunk"]
            head_inputs = modality_value["head"]
            # Token-level features straight out of the transformer trunk.
            word_feat = self.modality_trunks[modality_key](**trunk_inputs)
            # Pooled sequence embedding after the modality head and
            # post-processor (the usual ImageBind output).
            seq_feat = self.modality_heads[modality_key](word_feat, **head_inputs)
            seq_feat = self.modality_postprocessors[modality_key](seq_feat)
        # Note: if several modalities are passed at once, only the features
        # of the last one iterated are returned.
        return word_feat, seq_feat
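
# For reference (shapes are indicative, based on the "huge" configuration
# below, and not verified against every checkpoint): a batch of B text
# prompts yields word_feat of roughly [B, 77, 1024] (one vector per token
# slot) and seq_feat of [B, 1024] (the joint-embedding-space vector that
# stock ImageBind returns). Audio follows the same word/sequence split with
# its own trunk width.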
def imagebind_huge(pretrained=False, ckpt_path=None):
    model = FeatureExtractor(
        vision_embed_dim=1280,
        vision_num_blocks=32,
        vision_num_heads=16,
        text_embed_dim=1024,
        text_num_blocks=24,
        text_num_heads=16,
        out_embed_dim=1024,
        audio_drop_path=0.1,
        imu_drop_path=0.7,
    )
    if pretrained:
        # e.g. ckpt_path = 'data/motionverse/pretrained/imagebind_huge.pth'
        if not os.path.exists(ckpt_path):
            print(f"Downloading imagebind weights to {ckpt_path} ...")
            os.makedirs(os.path.dirname(ckpt_path) or ".", exist_ok=True)
            torch.hub.download_url_to_file(
                "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
                ckpt_path,
                progress=True,
            )
        model.load_state_dict(torch.load(ckpt_path))
    return model
def extract_text_feature(text, model, device):
    # `text` is a list of strings; ImageBind tokenizes and embeds them.
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(text, device),
    }
    with torch.no_grad():
        text_word_feat, text_seq_feat = model(inputs)
    return text_word_feat, text_seq_feat
def extract_audio_feature(audio_paths, model, device):
    inputs = {
        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device)
    }
    # The audio loader returns several clips per file; keep only the first
    # clip so the features align one-to-one with the input files.
    inputs[ModalityType.AUDIO] = inputs[ModalityType.AUDIO][:, :1]
    with torch.no_grad():
        audio_word_feat, audio_seq_feat = model(inputs)
    return audio_word_feat, audio_seq_feat
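
if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint and audio paths below are
    # placeholders, not files shipped with this module; adjust them to your
    # local layout before running.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = imagebind_huge(
        pretrained=True,
        ckpt_path="data/motionverse/pretrained/imagebind_huge.pth",  # placeholder
    )
    model.eval()
    model.to(device)

    word_feat, seq_feat = extract_text_feature(
        ["a person walks forward"], model, device
    )
    print("text:", word_feat.shape, seq_feat.shape)

    word_feat, seq_feat = extract_audio_feature(
        ["assets/sample.wav"], model, device  # placeholder path
    )
    print("audio:", word_feat.shape, seq_feat.shape)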