import logging
import math
import os
import subprocess
from io import BytesIO

import librosa
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from audio_separator.separator import Separator
from einops import rearrange
from funasr.download.download_from_hub import download_model
from funasr.models.emotion2vec.model import Emotion2vec
from transformers import Wav2Vec2FeatureExtractor

from memo.models.emotion_classifier import AudioEmotionClassifierModel
from memo.models.wav2vec import Wav2VecModel


logger = logging.getLogger(__name__)


def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int = 16000):
    """Resample the input audio file to `sample_rate` Hz with ffmpeg and return the output path."""
    p = subprocess.Popen(
        [
            "ffmpeg",
            "-y",
            "-v",
            "error",
            "-i",
            input_audio_file,
            "-ar",
            str(sample_rate),
            output_audio_file,
        ]
    )
    ret = p.wait()
    assert ret == 0, f"Resample audio failed! Input: {input_audio_file}, Output: {output_audio_file}"
    return output_audio_file
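
# Example usage of resample_audio (illustrative only; the file names below are
# hypothetical, and the ffmpeg binary must be available on PATH):
#
#   resampled_path = resample_audio("speech.wav", "speech-16k.wav", sample_rate=16000)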


@torch.no_grad()
def preprocess_audio(
    wav_path: str,
    fps: int,
    wav2vec_model: str,
    vocal_separator_model: str = None,
    cache_dir: str = "",
    device: str = "cuda",
    sample_rate: int = 16000,
    num_generated_frames_per_clip: int = -1,
):
    """
    Preprocess the audio file and extract audio embeddings.

    Args:
        wav_path (str): Path to the input audio file.
        fps (int): Frame rate used to convert the audio into per-frame embeddings.
        wav2vec_model (str): Path to the pretrained Wav2Vec model.
        vocal_separator_model (str, optional): Path to the vocal separator model. Defaults to None.
        cache_dir (str, optional): Directory for cached files. Defaults to "".
        device (str, optional): Device to use ('cuda' or 'cpu'). Defaults to "cuda".
        sample_rate (int, optional): Sampling rate for audio processing. Defaults to 16000.
        num_generated_frames_per_clip (int, optional): Number of generated frames per clip; when > 0,
            the audio is padded so the frame count is divisible by it. Defaults to -1 (no padding).

    Returns:
        tuple: A tuple containing:
            - audio_emb (torch.Tensor): The processed audio embeddings.
            - audio_length (int): The length of the audio in frames.
    """
    # Initialize the Wav2Vec audio encoder and freeze its feature extractor.
    audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model).to(device=device)
    audio_encoder.feature_extractor._freeze_parameters()

    # Initialize the Wav2Vec feature extractor.
    wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model)

    # Optionally load the vocal separation model.
    vocal_separator = None
    if vocal_separator_model is not None:
        os.makedirs(cache_dir, exist_ok=True)
        vocal_separator = Separator(
            output_dir=cache_dir,
            output_single_stem="vocals",
            model_file_dir=os.path.dirname(vocal_separator_model),
        )
        vocal_separator.load_model(os.path.basename(vocal_separator_model))
        assert vocal_separator.model_instance is not None, "Failed to load audio separation model."

    # Isolate the vocals and resample them, or fall back to the raw audio.
    if vocal_separator is not None:
        outputs = vocal_separator.separate(wav_path)
        assert len(outputs) > 0, "Audio separation failed."
        vocal_audio_file = outputs[0]
        vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
        vocal_audio_file = os.path.join(vocal_separator.output_dir, vocal_audio_file)
        vocal_audio_file = resample_audio(
            vocal_audio_file,
            os.path.join(vocal_separator.output_dir, f"{vocal_audio_name}-16k.wav"),
            sample_rate,
        )
    else:
        vocal_audio_file = wav_path

    # Load the audio and compute Wav2Vec input features.
    speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=sample_rate)
    audio_feature = np.squeeze(wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values)
    audio_length = math.ceil(len(audio_feature) / sample_rate * fps)
    audio_feature = torch.from_numpy(audio_feature).float().to(device=device)

    # Pad the audio so the frame count is a multiple of num_generated_frames_per_clip.
    if num_generated_frames_per_clip > 0 and audio_length % num_generated_frames_per_clip != 0:
        audio_feature = torch.nn.functional.pad(
            audio_feature,
            (
                0,
                (num_generated_frames_per_clip - audio_length % num_generated_frames_per_clip) * (sample_rate // fps),
            ),
            "constant",
            0.0,
        )
        audio_length += num_generated_frames_per_clip - audio_length % num_generated_frames_per_clip
    audio_feature = audio_feature.unsqueeze(0)

    # Extract per-layer hidden states from the audio encoder.
    with torch.no_grad():
        embeddings = audio_encoder(audio_feature, seq_len=audio_length, output_hidden_states=True)
        assert len(embeddings) > 0, "Failed to extract audio embeddings."
        audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
        audio_emb = rearrange(audio_emb, "b s d -> s b d")

    # Stack each frame with its two preceding and two following frames (5-frame context window).
    audio_emb = audio_emb.cpu().detach()
    concatenated_tensors = []
    for i in range(audio_emb.shape[0]):
        vectors_to_concat = [audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
        concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
    audio_emb = torch.stack(concatenated_tensors, dim=0)

    # Release the models.
    if vocal_separator is not None:
        del vocal_separator
    del audio_encoder

    return audio_emb, audio_length
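
# Example usage of preprocess_audio (illustrative sketch; the checkpoint paths and the
# clip length below are placeholders, not values shipped with this module):
#
#   audio_emb, audio_length = preprocess_audio(
#       wav_path="input.wav",
#       fps=30,
#       wav2vec_model="checkpoints/wav2vec2",
#       vocal_separator_model="checkpoints/vocals_separator.onnx",
#       cache_dir="./cache",
#       num_generated_frames_per_clip=16,
#   )
#
# The result stacks, for every output frame, the per-layer Wav2Vec hidden states of a
# 5-frame context window (two frames before and two after the current frame).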


@torch.no_grad()
def extract_audio_emotion_labels(
    model: str,
    wav_path: str,
    emotion2vec_model: str,
    audio_length: int,
    sample_rate: int = 16000,
    device: str = "cuda",
):
    """
    Extract audio emotion labels from an audio file.

    Args:
        model (str): Path to the MEMO model.
        wav_path (str): Path to the input audio file.
        emotion2vec_model (str): Name or path of the Emotion2vec model to download from ModelScope.
        audio_length (int): Target length for interpolated emotion labels.
        sample_rate (int, optional): Sample rate of the input audio. Default is 16000.
        device (str, optional): Device to use ('cuda' or 'cpu'). Default is "cuda".

    Returns:
        tuple: A tuple containing:
            - emotion_labels (torch.Tensor): Emotion labels interpolated to the target audio length.
            - num_emotion_classes (int): Number of emotion classes predicted by the classifier.
    """
    # Load the emotion2vec feature extractor.
    logger.info("Downloading emotion2vec models from modelscope")
    kwargs = download_model(model=emotion2vec_model)
    kwargs["tokenizer"] = None
    kwargs["input_size"] = None
    kwargs["frontend"] = None
    emotion_model = Emotion2vec(**kwargs, vocab_size=-1).to(device)
    init_param = kwargs.get("init_param", None)
    load_emotion2vec_model(
        model=emotion_model,
        path=init_param,
        ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
        oss_bucket=kwargs.get("oss_bucket", None),
        scope_map=kwargs.get("scope_map", []),
    )
    emotion_model.eval()

    # Load the audio emotion classifier head from the MEMO checkpoint.
    classifier = AudioEmotionClassifierModel.from_pretrained(
        model,
        subfolder="misc/audio_emotion_classifier",
        use_safetensors=True,
    ).to(device=device)
    classifier.eval()

    # Load the audio, resample it if needed, and keep a single channel.
    wav, sr = torchaudio.load(wav_path)
    if sr != sample_rate:
        wav = torchaudio.functional.resample(wav, sr, sample_rate)
    wav = wav.view(-1) if wav.dim() == 1 else wav[0].view(-1)

    emotion_labels = torch.full_like(wav, -1, dtype=torch.int32)

    def extract_emotion(x):
        """Predict the emotion label for a single audio segment."""
        x = x.to(device=device)
        x = F.layer_norm(x, x.shape).view(1, -1)
        feats = emotion_model.extract_features(x)
        x = feats["x"].mean(dim=1)
        x = classifier(x)
        x = torch.softmax(x, dim=-1)
        return torch.argmax(x, dim=-1)

    # Predict emotion labels on sliding windows: the first 2 seconds, 3-second windows
    # centered on each subsequent second, and the last 2 seconds.
    start_label = extract_emotion(wav[: sample_rate * 2]).item()
    emotion_labels[:sample_rate] = start_label

    for i in range(sample_rate, len(wav) - sample_rate, sample_rate):
        mid_wav = wav[i - sample_rate : i - sample_rate + sample_rate * 3]
        mid_label = extract_emotion(mid_wav).item()
        emotion_labels[i : i + sample_rate] = mid_label

    end_label = extract_emotion(wav[-sample_rate * 2 :]).item()
    emotion_labels[-sample_rate:] = end_label

    # Interpolate the per-second labels to one label per generated frame.
    emotion_labels = emotion_labels.unsqueeze(0).unsqueeze(0).float()
    emotion_labels = F.interpolate(emotion_labels, size=audio_length, mode="nearest").squeeze(0).squeeze(0).int()
    num_emotion_classes = classifier.num_emotion_classes

    # Release the models.
    del emotion_model
    del classifier

    return emotion_labels, num_emotion_classes
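
# Example usage of extract_audio_emotion_labels (illustrative sketch; the paths and the
# emotion2vec model id are placeholders, and `audio_length` would normally come from
# preprocess_audio above):
#
#   emotion_labels, num_classes = extract_audio_emotion_labels(
#       model="checkpoints/memo",
#       wav_path="input.wav",
#       emotion2vec_model="iic/emotion2vec_plus_large",
#       audio_length=audio_length,
#   )
#
# Labels are predicted on overlapping two- to three-second windows (one label per second
# of audio) and then nearest-neighbor interpolated to one label per generated frame.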


def load_emotion2vec_model(
    path: str,
    model: torch.nn.Module,
    ignore_init_mismatch: bool = True,
    map_location: str = "cpu",
    oss_bucket=None,
    scope_map=None,
):
    """Load emotion2vec checkpoint weights from `path` into `model`, remapping parameter names via `scope_map`."""
    obj = model
    dst_state = obj.state_dict()
    logger.debug(f"Emotion2vec checkpoint: {path}")
    if oss_bucket is None:
        src_state = torch.load(path, map_location=map_location)
    else:
        buffer = BytesIO(oss_bucket.get_object(path).read())
        src_state = torch.load(buffer, map_location=map_location)

    # Unwrap nested checkpoint dictionaries.
    src_state = src_state["state_dict"] if "state_dict" in src_state else src_state
    src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state
    src_state = src_state["model"] if "model" in src_state else src_state

    # Build the (src_prefix, dst_prefix) pairs without mutating the caller's list
    # (or a shared default argument) in place.
    if isinstance(scope_map, str):
        scope_map = scope_map.split(",")
    scope_map = list(scope_map or []) + ["module.", "None"]

    for k in dst_state.keys():
        k_src = k
        if scope_map is not None:
            src_prefix = ""
            dst_prefix = ""
            for i in range(0, len(scope_map), 2):
                src_prefix = scope_map[i] if scope_map[i].lower() != "none" else ""
                dst_prefix = scope_map[i + 1] if scope_map[i + 1].lower() != "none" else ""

                if dst_prefix == "" and (src_prefix + k) in src_state.keys():
                    k_src = src_prefix + k
                    if not k_src.startswith("module."):
                        logger.debug(f"init param, map: {k} from {k_src} in ckpt")
                elif k.startswith(dst_prefix) and k.replace(dst_prefix, src_prefix, 1) in src_state.keys():
                    k_src = k.replace(dst_prefix, src_prefix, 1)
                    if not k_src.startswith("module."):
                        logger.debug(f"init param, map: {k} from {k_src} in ckpt")

        if k_src in src_state.keys():
            if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape:
                logger.debug(
                    f"ignore_init_mismatch:{ignore_init_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}"
                )
            else:
                dst_state[k] = src_state[k_src]
        else:
            logger.debug(f"Warning, miss key in ckpt: {k}, mapped: {k_src}")

    obj.load_state_dict(dst_state, strict=True)
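

# Minimal end-to-end sketch, assuming locally downloaded checkpoints; the paths, model
# identifiers, and fps/clip settings below are placeholders for illustration only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    demo_wav = "input.wav"  # placeholder input audio
    audio_emb, audio_length = preprocess_audio(
        wav_path=demo_wav,
        fps=30,
        wav2vec_model="checkpoints/wav2vec2",  # placeholder Wav2Vec checkpoint directory
        vocal_separator_model=None,  # skip vocal separation in this sketch
        cache_dir="./cache",
        num_generated_frames_per_clip=16,
    )
    emotion_labels, num_emotion_classes = extract_audio_emotion_labels(
        model="checkpoints/memo",  # placeholder MEMO checkpoint directory
        wav_path=demo_wav,
        emotion2vec_model="iic/emotion2vec_plus_large",  # placeholder emotion2vec model id
        audio_length=audio_length,
    )
    logger.info(
        "audio_emb shape: %s, audio_length: %d, emotion classes: %d",
        tuple(audio_emb.shape),
        audio_length,
        num_emotion_classes,
    )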