import os

import numpy as np
import torch

from .whisper import load_model


class Audio2Feature:
    def __init__(
        self,
        model_path="checkpoints/whisper/tiny.pt",
        device=None,
        audio_embeds_cache_dir=None,
        num_frames=16,
    ):
        self.model = load_model(model_path, device)
        self.audio_embeds_cache_dir = audio_embeds_cache_dir
        self.num_frames = num_frames
        self.embedding_dim = self.model.dims.n_audio_state

    def get_sliced_feature(self, feature_array, vid_idx, audio_feat_length=[2, 2], fps=25):
        """
        Slice a window of audio features centered on a given video frame.

        :param feature_array: sequence of Whisper encoder features (50 feature rows per second of audio)
        :param vid_idx: index of the video frame to center the window on
        :param audio_feat_length: number of video frames of context to take on the left and right
        :param fps: video frame rate
        :return: (selected_feature, selected_idx), where selected_feature has shape (rows, embedding_dim)
        """
        length = len(feature_array)
        selected_feature = []
        selected_idx = []
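
        # Whisper's encoder emits 50 feature rows per second of audio, so video frame
        # vid_idx (at `fps` frames per second) lines up with audio row vid_idx * 50 / fps.
        # At 25 fps each video frame spans two audio rows, so the window below covers
        # audio_feat_length[0] frames of left context and audio_feat_length[1] frames of
        # right context around the center frame, i.e. (left + right + 1) * 2 rows in total.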
        center_idx = int(vid_idx * 50 / fps)
        left_idx = center_idx - audio_feat_length[0] * 2
        right_idx = center_idx + (audio_feat_length[1] + 1) * 2

        for idx in range(left_idx, right_idx):
            idx = max(0, idx)
            idx = min(length - 1, idx)
            x = feature_array[idx]
            selected_feature.append(x)
            selected_idx.append(idx)

        selected_feature = torch.cat(selected_feature, dim=0)
        selected_feature = selected_feature.reshape(-1, self.embedding_dim)
        return selected_feature, selected_idx

    def get_sliced_feature_sparse(self, feature_array, vid_idx, audio_feat_length=[2, 2], fps=25):
        """
        Slice audio features for a given video frame by sparse sampling.

        :param feature_array: sequence of Whisper encoder features (50 feature rows per second of audio)
        :param vid_idx: index of the video frame to center the window on
        :param audio_feat_length: number of video frames of context to take on the left and right
        :param fps: video frame rate
        :return: (selected_feature, selected_idx)
        """
        length = len(feature_array)
        selected_feature = []
        selected_idx = []
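
        # Unlike get_sliced_feature, which takes one contiguous run of rows, this variant
        # samples two audio rows per context frame: each offset dt is mapped to its audio
        # row, and that row is taken together with the one before it, duplicating the
        # boundary row whenever the index falls outside the feature array.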
        for dt in range(-audio_feat_length[0], audio_feat_length[1] + 1):
            left_idx = int((vid_idx + dt) * 50 / fps)

            if left_idx < 1 or left_idx > length - 1:
                left_idx = max(0, left_idx)
                left_idx = min(length - 1, left_idx)

                x = feature_array[left_idx]
                x = x[np.newaxis, :, :]
                x = np.repeat(x, 2, axis=0)
                selected_feature.append(x)
                selected_idx.append(left_idx)
                selected_idx.append(left_idx)
            else:
                x = feature_array[left_idx - 1 : left_idx + 1]
                selected_feature.append(x)
                selected_idx.append(left_idx - 1)
                selected_idx.append(left_idx)

        selected_feature = np.concatenate(selected_feature, axis=0)
        selected_feature = selected_feature.reshape(-1, self.embedding_dim)
        selected_feature = torch.from_numpy(selected_feature)
        return selected_feature, selected_idx

    def feature2chunks(self, feature_array, fps, audio_feat_length=[2, 2]):
        whisper_chunks = []
        whisper_idx_multiplier = 50.0 / fps
        i = 0
        print(f"video at {fps} FPS, audio features at 50 FPS")
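
        # One chunk per video frame: keep slicing windows until the frame's audio index
        # runs past the end of the feature array.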
        while True:
            start_idx = int(i * whisper_idx_multiplier)
            selected_feature, selected_idx = self.get_sliced_feature(
                feature_array=feature_array, vid_idx=i, audio_feat_length=audio_feat_length, fps=fps
            )

            whisper_chunks.append(selected_feature)
            i += 1
            if start_idx > len(feature_array):
                break

        return whisper_chunks

    def _audio2feat(self, audio_path: str):
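        # transcribe() here comes from the bundled .whisper module, which (unlike stock
        # openai-whisper) is expected to attach per-segment encoder embeddings to its output.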
        result = self.model.transcribe(audio_path)
        embed_list = []
        for emb in result["segments"]:
            encoder_embeddings = emb["encoder_embeddings"]
            encoder_embeddings = encoder_embeddings.transpose(0, 2, 1, 3)
            encoder_embeddings = encoder_embeddings.squeeze(0)
            start_idx = int(emb["start"])
            end_idx = int(emb["end"])
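            # Keep only the first (end - start) / 2 embedding rows of each segment; the factor
            # of two presumably reflects the encoder emitting one row per two mel frames.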
            emb_end_idx = int((end_idx - start_idx) / 2)
            embed_list.append(encoder_embeddings[:emb_end_idx])
        concatenated_array = torch.from_numpy(np.concatenate(embed_list, axis=0))
        return concatenated_array

    def audio2feat(self, audio_path):
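        # Features are cached per audio file at <cache_dir>/<basename>.pt. A cache entry that
        # fails to load is treated as corrupt: it is deleted, recomputed, and rewritten.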
        if self.audio_embeds_cache_dir == "" or self.audio_embeds_cache_dir is None:
            return self._audio2feat(audio_path)

        audio_embeds_cache_path = os.path.join(self.audio_embeds_cache_dir, os.path.basename(audio_path) + ".pt")

        if os.path.isfile(audio_embeds_cache_path):
            try:
                audio_feat = torch.load(audio_embeds_cache_path)
            except Exception as e:
                print(f"{type(e).__name__} - {e} - {audio_embeds_cache_path}")
                os.remove(audio_embeds_cache_path)
                audio_feat = self._audio2feat(audio_path)
                torch.save(audio_feat, audio_embeds_cache_path)
        else:
            audio_feat = self._audio2feat(audio_path)
            torch.save(audio_feat, audio_embeds_cache_path)

        return audio_feat

    def crop_overlap_audio_window(self, audio_feat, start_index):
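        # Stacks one sliced window per video frame for self.num_frames consecutive frames,
        # starting at start_index; the frame rate is fixed at 25 fps here.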
        selected_feature_list = []
        for i in range(start_index, start_index + self.num_frames):
            selected_feature, selected_idx = self.get_sliced_feature(
                feature_array=audio_feat, vid_idx=i, audio_feat_length=[2, 2], fps=25
            )
            selected_feature_list.append(selected_feature)
        mel_overlap = torch.stack(selected_feature_list)
        return mel_overlap


if __name__ == "__main__":
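    # Quick sanity check: extract features for the demo clip and print the window sliced
    # for each video frame.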
    audio_encoder = Audio2Feature(model_path="checkpoints/whisper/tiny.pt")
    audio_path = "assets/demo1_audio.wav"
    array = audio_encoder.audio2feat(audio_path)
    print(array.shape)
    fps = 25
    whisper_idx_multiplier = 50.0 / fps

    i = 0
    print(f"video at {fps} FPS, audio features at 50 FPS")
    while True:
        start_idx = int(i * whisper_idx_multiplier)
        selected_feature, selected_idx = audio_encoder.get_sliced_feature(
            feature_array=array, vid_idx=i, audio_feat_length=[2, 2], fps=fps
        )
        print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}")
        i += 1
        if start_idx > len(array):
            break