# hainazhu: Add application file (258fd02)
import os, sys
from transformers import AutoModel
import torch
from torch import nn
import torchaudio.transforms as T
import einops
import numpy as np
import joblib
from torch.nn.utils.rnn import pad_sequence
def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Build a boolean padding mask from a batch of sequence lengths.

    Args:
        lengths: A 1-D tensor containing sequence lengths.

    Returns:
        A 2-D bool tensor of shape ``(len(lengths), max(lengths))`` where
        padded positions are ``True`` and valid positions are ``False``.
        An empty ``lengths`` tensor yields an empty ``(0, 0)`` mask.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False,  True,  True,  True,  True],
            [False, False, False,  True,  True],
            [False, False,  True,  True,  True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    # ``max`` of an empty tensor raises, so handle the empty batch explicitly.
    max_len = int(lengths.max()) if lengths.numel() > 0 else 0
    # Create the position index directly with the dtype/device of ``lengths``
    # instead of building it on the CPU and copying it over afterwards.
    positions = torch.arange(max_len, device=lengths.device, dtype=lengths.dtype)
    # Broadcasting (1, max_len) against (n, 1) produces the (n, max_len) mask.
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)
class KmeansQuantizer(nn.Module):
    """Quantizer that maps features to k-means cluster indices and back.

    The centroids are stored in ``self.clusters`` as an ``nn.Parameter`` of
    shape ``(n_cluster, feature_dim)``.
    """

    def __init__(self, centroids) -> None:
        """
        Args:
            centroids: ``(n_cluster, feature_dim)`` array of cluster centres,
                given either as a NumPy array or as a torch tensor.
        """
        super().__init__()
        # isinstance (rather than an exact type comparison) also accepts
        # ndarray subclasses such as ``np.memmap``.
        if isinstance(centroids, np.ndarray):
            centroids = torch.from_numpy(centroids)
        self.clusters = nn.Parameter(centroids)

    @classmethod
    def from_pretrained(cls, km_path):
        """Build a quantizer from a joblib dump exposing ``cluster_centers_``.

        NOTE: ``joblib.load`` unpickles the file, which can execute arbitrary
        code -- only load k-means files from a trusted source.
        """
        km_model = joblib.load(km_path)
        return cls(km_model.cluster_centers_)

    @property
    def n_cluster(self) -> int:
        """Number of centroids (rows of ``self.clusters``)."""
        return self.clusters.shape[0]

    @property
    def feature_dim(self) -> int:
        """Dimensionality of each centroid (columns of ``self.clusters``)."""
        return self.clusters.shape[1]

    def forward(self, inp: torch.Tensor):
        """Dispatch on the input rank.

        A 3-D input whose last dimension equals ``feature_dim`` is treated as
        features and quantized to indices; an input with fewer than three
        dimensions is treated as indices and mapped back to centroids.

        Raises:
            NotImplementedError: for any other input shape.
        """
        if inp.ndim == 3 and inp.shape[-1] == self.feature_dim:
            return self.feat2indice(inp)
        if inp.ndim < 3:
            return self.indice2feat(inp)
        raise NotImplementedError(
            f"unsupported input shape {tuple(inp.shape)}; expected "
            f"(B, T, {self.feature_dim}) features or indices with fewer than 3 dims"
        )

    def feat2indice(self, feat):
        '''
        feat: B,T,D

        Returns the (B, T) index of the nearest centroid (L2 distance) for
        every frame.
        '''
        # Centroids coming from scikit-learn are often float64 while features
        # are float32; cdist needs matching dtypes, so compare in feat's dtype.
        centers = self.clusters.to(dtype=feat.dtype)
        # ``expand`` yields a broadcast view, so the codebook is not copied
        # once per batch element here.
        centers = centers.unsqueeze(0).expand(feat.shape[0], -1, -1)
        dists = torch.cdist(feat, centers, p=2)
        return dists.argmin(dim=-1)

    def indice2feat(self, indice):
        '''
        indice: B, T

        Returns the centroid vector for every index (embedding lookup).
        '''
        return nn.functional.embedding(input=indice, weight=self.clusters)
class MERTwithKmeans(nn.Module):
    """MERT feature extractor with optional k-means quantization.

    Audio is resampled to 24 kHz, passed through the MERT model, and the
    selected hidden layer is resized to 75 frames per second of 24 kHz audio.
    The frames can then be average-pooled and mapped to k-means indices.
    """

    def __init__(self, pretrained_model_name_or_path, kmeans_path=None, sampling_rate=44100, output_layer=-1, mean_pool=1) -> None:
        """
        Args:
            pretrained_model_name_or_path: must be ``"MERT-v1-330M"``; it is
                only validated, the weights are loaded from a fixed local
                snapshot path.
            kmeans_path: optional joblib dump of a fitted k-means model; when
                omitted no quantizer is created.
            sampling_rate: sampling rate of the audio passed to ``forward``.
            output_layer: index into the model's ``hidden_states``.
            mean_pool: average-pooling window (and stride) applied over time;
                must be odd.
        """
        super().__init__()
        assert pretrained_model_name_or_path == "MERT-v1-330M"
        # loading our model weights
        self.model = AutoModel.from_pretrained('pretrained/models--m-a-p--MERT-v1-330M/snapshots/af10da70c94a0c849de9cc94b83e12769c4db499', trust_remote_code=True)
        if kmeans_path is not None:
            # NOTE: the k-means file is unpickled by joblib -- trusted files only.
            self.kmeans = KmeansQuantizer.from_pretrained(kmeans_path)
        else:
            self.kmeans = None
        # make sure the sample_rate aligned
        self.sampling_rate = sampling_rate
        # nn.Identity (instead of a lambda) keeps the module picklable.
        self.resampler = T.Resample(sampling_rate, 24000) if sampling_rate != 24000 else nn.Identity()
        # NOTE(review): with the assertion above this is always False --
        # confirm against the checkpoint's preprocessor configuration.
        self.do_normalization = (pretrained_model_name_or_path == "MERT-v1-95M")
        self.output_layer = output_layer
        self.mean_pool = mean_pool
        assert self.mean_pool % 2 == 1

    @torch.no_grad()
    def forward(self, input_audio, seq_len=None, apply_kmeans=True):
        '''
        input_audio: B,T
        seq_len: B,

        Returns the features (or the k-means indices when ``apply_kmeans`` is
        true); when ``seq_len`` was given, the per-item output lengths are
        returned as a second value.

        Raises:
            RuntimeError: if ``apply_kmeans`` is true but no k-means model
                was loaded.
        '''
        # Fail fast, before running the model, when quantization is impossible.
        if apply_kmeans and self.kmeans is None:
            raise RuntimeError(
                "apply_kmeans=True requires a k-means model; pass kmeans_path "
                "at construction or call with apply_kmeans=False"
            )

        device = input_audio.device
        return_seq_len = seq_len is not None
        if seq_len is None:
            # No lengths supplied: every item spans the full time axis.
            seq_len = [input_audio.shape[1]] * input_audio.shape[0]

        # Trim each item to its valid length, then resample it to 24 kHz.
        input_audio = [self.resampler(x[:l]) for x, l in zip(input_audio, seq_len)]
        new_seq_len = torch.tensor([len(i) for i in input_audio], device=device)

        if self.do_normalization:
            input_audio = self.zero_mean_unit_var_norm(input_audio, new_seq_len)

        padded_input = pad_sequence(input_audio, batch_first=True)
        attention_mask = ~make_pad_mask(new_seq_len)

        outputs = self.model(input_values=padded_input, attention_mask=attention_mask, output_hidden_states=True)
        output = outputs['hidden_states'][self.output_layer]

        # Expected frame count per item: 75 frames per second of 24 kHz audio.
        output_len = torch.round(new_seq_len.float() / 24000 * 75).long()
        # Resize the time axis so it matches the longest expected length.
        output = nn.functional.interpolate(output.transpose(-1, -2), output_len.max().item()).transpose(-1, -2)

        if self.mean_pool > 1:
            # The lengths must shrink by the same factor as the pooled features
            # (kernel == stride == mean_pool), not by a hard-coded 3.
            output_len = output_len // self.mean_pool
            output = nn.functional.avg_pool1d(output.transpose(-1, -2), kernel_size=self.mean_pool, stride=self.mean_pool)
            output = output.transpose(-1, -2)

        if apply_kmeans:
            output = self.kmeans.feat2indice(output)

        if return_seq_len:
            return output, output_len
        return output

    # from transformers.models.wav2vec2.feature_extraction_wav2vec2
    # rewrite it by pytorch
    @staticmethod
    def zero_mean_unit_var_norm(
        input_values: torch.Tensor, seq_len: torch.Tensor = None, padding_value: float = 0.0
    ):
        """
        Every array in the list is normalized to have zero mean and unit variance

        When ``seq_len`` is given, each item is normalized with the statistics
        of its first ``length`` samples, the remainder is set to
        ``padding_value``, and a list of tensors is returned. Otherwise
        ``input_values`` is normalized along its last dimension and a tensor
        is returned.

        The population variance (``unbiased=False``) is used, which is what
        NumPy's ``.var()`` computes by default; it also keeps one-sample
        inputs finite instead of producing NaN.
        """
        if seq_len is None:
            mean = input_values.mean(dim=-1, keepdim=True)
            var = input_values.var(dim=-1, keepdim=True, unbiased=False)
            return (input_values - mean) / torch.sqrt(var + 1e-7)

        normed_input_values = []
        for vector, length in zip(input_values, seq_len):
            valid = vector[:length]
            normed_slice = (vector - valid.mean()) / torch.sqrt(valid.var(unbiased=False) + 1e-7)
            if length < normed_slice.shape[0]:
                # Positions past the valid length carry the padding value.
                normed_slice[length:] = padding_value
            normed_input_values.append(normed_slice)
        return normed_input_values